In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader, random_split
from torchvision import datasets, transforms
from torchvision.transforms import v2
import matplotlib.pyplot as plt
from sklearn.model_selection import train_test_split
import pandas as pd
import numpy as np
import pickle
import gc
import os
from tqdm import tqdm
from torchsummary import summary

In [2]:
torch.manual_seed(42)

<torch._C.Generator at 0x7fc23ab71830>

In [3]:
device = 'cuda' if torch.cuda.is_available() else 'cpu'
device

'cuda'

In [4]:
data_transform = v2.Compose([
    v2.Resize((180, 180)),
    v2.ToImage(),
    v2.ToDtype(torch.float32, scale=True),
    v2.Normalize([0.5], [0.5]),
    v2.RandomInvert() 
])

In [5]:

data = datasets.ImageFolder("data/kagglecatsanddogs_5340/PetImages", transform = data_transform)

In [6]:
torch.manual_seed(42)
train_data, test_data = random_split(data, [.7, .3])
train_loader = DataLoader(dataset= train_data, num_workers= 12, batch_size= 512, shuffle= True)
test_loader  = DataLoader(dataset= test_data,  num_workers= 12, batch_size= 512, shuffle= True)

In [7]:
class Cats_n_Dogs_net(nn.Module):
    def __init__(self, random_state):
        super(Cats_n_Dogs_net, self).__init__()
        
        self.conv1 = nn.Conv2d(in_channels = 3, out_channels=32, kernel_size = (3,3)) # -4
        self.relu1 = nn.ReLU()
        self.maxpool1 = nn.MaxPool2d(kernel_size=(2,2))# // 2
        
        self.conv2 = nn.Conv2d(in_channels = 32, out_channels=64, kernel_size=(3,3)) # -2
        self.relu2 = nn.ReLU()
        self.maxpool2 = nn.MaxPool2d(kernel_size=(2,2))# // 2
        
        self.conv3 = nn.Conv2d(in_channels = 64, out_channels = 128, kernel_size = (3,3)) # -2
        self.relu3 = nn.ReLU()
        self.maxpool3 = nn.MaxPool2d(kernel_size=(2,2))# // 2
        
        self.conv4 = nn.Conv2d(in_channels=128, out_channels=256, kernel_size=3)
        self.relu4 = nn.ReLU()
        self.maxpool4 = nn.MaxPool2d(kernel_size=(2,2))
        
        self.conv5 = nn.Conv2d(in_channels=256, out_channels=256, kernel_size=3)
        self.relu5 = nn.ReLU()
        
        torch.manual_seed(random_state)
        self.dropout1 = nn.Dropout2d(p = .5)
        self.flatten1 = nn.Flatten(1)
        
        self.linear1 = nn.LazyLinear(out_features=256)
        self.relu2 = nn.ReLU()
        
        self.dropout2 = nn.Dropout2d(p = 0.2)
        self.linear2 = nn.Linear(in_features = 256, out_features = 2)
        self.softmax1 = nn.Softmax(dim = 1)
        
    def forward(self, x):
        x = self.conv1(x)
        x = self.relu1(x)
        x = self.maxpool1(x)
        
        x = self.conv2(x)
        x = self.relu2(x)
        x = self.maxpool2(x)
        
        x = self.conv3(x)
        x = self.relu3(x)
        x = self.maxpool3(x)
        
        x = self.conv4(x)
        x = self.relu4(x)
        x = self.maxpool4(x)
        
        x = self.conv5(x)
        x = self.relu5(x)
        
        x = self.dropout1(x)
        x = self.flatten1(x)
        
        x = self.linear1(x)
        x = self.relu2(x)
        
        x = self.dropout2(x)
        x = self.linear2(x)
        x = self.softmax1(x)
        return x

In [8]:
model = Cats_n_Dogs_net(42).to(device)
summary(model, (3,180, 180))

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1         [-1, 32, 178, 178]             896
              ReLU-2         [-1, 32, 178, 178]               0
         MaxPool2d-3           [-1, 32, 89, 89]               0
            Conv2d-4           [-1, 64, 87, 87]          18,496
              ReLU-5           [-1, 64, 87, 87]               0
         MaxPool2d-6           [-1, 64, 43, 43]               0
            Conv2d-7          [-1, 128, 41, 41]          73,856
              ReLU-8          [-1, 128, 41, 41]               0
         MaxPool2d-9          [-1, 128, 20, 20]               0
           Conv2d-10          [-1, 256, 18, 18]         295,168
             ReLU-11          [-1, 256, 18, 18]               0
        MaxPool2d-12            [-1, 256, 9, 9]               0
           Conv2d-13            [-1, 256, 7, 7]         590,080
             ReLU-14            [-1, 25



In [9]:
def accuracy(fx, y):
    return (torch.argmax(fx, dim = 1) == y).float().sum()

In [10]:
def fit(model, train_data, test_data, loss_fn = nn.CrossEntropyLoss(), optimizer = optim.Adam, epochs = 500):
    optimizer = optimizer(model.parameters())
    loss_per_epoch = []
    acc_per_epoch = []
    for e in tqdm(range(epochs)):
        running_test_loss = running_train_loss = 0.
        test_acc = train_acc = 0.
        test_size = train_size = 0
        
        model.train()
        for i, (x_train, y_train) in enumerate(train_data):
            x_train = x_train.type(torch.float32).to(device)
            y_train = y_train.to(device)
            
            train_preds = model(x_train)
            train_loss = loss_fn(train_preds, y_train)
            running_train_loss += train_loss.item()
            train_acc += accuracy(train_preds, y_train)
            train_size += x_train.shape[0]
            optimizer.zero_grad()
            train_loss.backward()
            optimizer.step()
        
        model.eval()
        with torch.inference_mode():   
            for j, (x_test, y_test) in enumerate(test_data):
                x_test = x_test.type(torch.float32).to(device)
                y_test = y_test.to(device)
                test_preds = model(x_test)
                running_test_loss += loss_fn(test_preds, y_test)
                test_acc += accuracy(test_preds, y_test)
                test_size += x_test.shape[0]
            print(f"Progress:\n\tTrain loss: {running_train_loss / i} | Train accuracy: {train_acc / train_size}\n\t Test loss: {running_test_loss / j} | Test accuracy: {test_acc / test_size}", end = '\r')
            

In [11]:
import warnings
warnings.simplefilter("ignore")
gc.collect()
torch.cuda.empty_cache()
fit(model, train_loader, test_loader, epochs=50)

  2%|▏         | 1/50 [00:20<17:01, 20.85s/it]

Progress:
	Train loss: 0.7111415687729331 | Train accuracy: 0.5348877310752869
	 Test loss: 0.7366393804550171 | Test accuracy: 0.5426056981086731

  4%|▍         | 2/50 [00:40<16:19, 20.40s/it]

Progress:
	Train loss: 0.6830433642162996 | Train accuracy: 0.6036345362663269
	 Test loss: 0.6914387941360474 | Test accuracy: 0.6268836259841919

  6%|▌         | 3/50 [01:01<15:59, 20.41s/it]

Progress:
	Train loss: 0.6650546999538646 | Train accuracy: 0.6278073191642761
	 Test loss: 0.6745916604995728 | Test accuracy: 0.6506201028823853

  8%|▊         | 4/50 [01:21<15:39, 20.43s/it]

Progress:
	Train loss: 0.6409470526611104 | Train accuracy: 0.6582090854644775
	 Test loss: 0.6575915217399597 | Test accuracy: 0.6650220155715942

 10%|█         | 5/50 [01:42<15:21, 20.47s/it]

Progress:
	Train loss: 0.6119371249395258 | Train accuracy: 0.6936967968940735
	 Test loss: 0.6457648277282715 | Test accuracy: 0.6882250905036926

 12%|█▏        | 6/50 [02:02<15:02, 20.51s/it]

Progress:
	Train loss: 0.598270528456744 | Train accuracy: 0.710669219493866
	 Test loss: 0.6232707500457764 | Test accuracy: 0.7118282914161682

 14%|█▍        | 7/50 [02:23<14:45, 20.59s/it]

Progress:
	Train loss: 0.5746701762956732 | Train accuracy: 0.7354134917259216
	 Test loss: 0.6006184816360474 | Test accuracy: 0.7330310940742493

 16%|█▌        | 8/50 [02:43<14:20, 20.49s/it]

Progress:
	Train loss: 0.5625975272234749 | Train accuracy: 0.7501000165939331
	 Test loss: 0.6041407585144043 | Test accuracy: 0.7319642901420593

 18%|█▊        | 9/50 [03:04<13:57, 20.42s/it]

Progress:
	Train loss: 0.5578710691017263 | Train accuracy: 0.7551288604736328
	 Test loss: 0.5794642567634583 | Test accuracy: 0.7539672255516052

 20%|██        | 10/50 [03:24<13:36, 20.40s/it]

Progress:
	Train loss: 0.5407293114592048 | Train accuracy: 0.7732442021369934
	 Test loss: 0.5650185346603394 | Test accuracy: 0.7687692046165466

 22%|██▏       | 11/50 [03:45<13:15, 20.40s/it]

Progress:
	Train loss: 0.5301258301033693 | Train accuracy: 0.7860449552536011
	 Test loss: 0.5501335859298706 | Test accuracy: 0.7862381935119629

 24%|██▍       | 12/50 [04:05<12:54, 20.39s/it]

Progress:
	Train loss: 0.5185237307758892 | Train accuracy: 0.7981598973274231
	 Test loss: 0.5548650622367859 | Test accuracy: 0.7817042469978333

 26%|██▌       | 13/50 [04:26<12:37, 20.48s/it]

Progress:
	Train loss: 0.5123467568088981 | Train accuracy: 0.8065032362937927
	 Test loss: 0.5266568660736084 | Test accuracy: 0.8130417466163635

 28%|██▊       | 14/50 [04:46<12:17, 20.50s/it]

Progress:
	Train loss: 0.5007637870662353 | Train accuracy: 0.8194754123687744
	 Test loss: 0.525129497051239 | Test accuracy: 0.8113082051277161

 30%|███       | 15/50 [05:07<12:00, 20.58s/it]

Progress:
	Train loss: 0.4895942430285847 | Train accuracy: 0.8283330798149109
	 Test loss: 0.5182669162750244 | Test accuracy: 0.8181090950965881

 32%|███▏      | 16/50 [05:27<11:35, 20.45s/it]

Progress:
	Train loss: 0.4813130375216989 | Train accuracy: 0.8387908339500427
	 Test loss: 0.5295454263687134 | Test accuracy: 0.8071743249893188

 34%|███▍      | 17/50 [05:48<11:15, 20.47s/it]

Progress:
	Train loss: 0.47728481012232166 | Train accuracy: 0.841590940952301
	 Test loss: 0.500531792640686 | Test accuracy: 0.8397119641304016

 36%|███▌      | 18/50 [06:08<10:52, 20.39s/it]

Progress:
	Train loss: 0.47028472581330466 | Train accuracy: 0.8486770987510681
	 Test loss: 0.5063269138336182 | Test accuracy: 0.831710934638977

 38%|███▊      | 19/50 [06:28<10:35, 20.49s/it]

Progress:
	Train loss: 0.4664141816251418 | Train accuracy: 0.8529630303382874
	 Test loss: 0.4923287332057953 | Test accuracy: 0.8453127145767212

 40%|████      | 20/50 [06:49<10:14, 20.50s/it]

Progress:
	Train loss: 0.45848569975179787 | Train accuracy: 0.8626207709312439
	 Test loss: 0.5038517713546753 | Test accuracy: 0.8334444761276245

 42%|████▏     | 21/50 [07:09<09:54, 20.49s/it]

Progress:
	Train loss: 0.45640787832877217 | Train accuracy: 0.8636493682861328
	 Test loss: 0.4930516183376312 | Test accuracy: 0.8465129137039185

 44%|████▍     | 22/50 [07:30<09:34, 20.53s/it]

Progress:
	Train loss: 0.4598664518664865 | Train accuracy: 0.861134946346283
	 Test loss: 0.4890144169330597 | Test accuracy: 0.8473129868507385

 46%|████▌     | 23/50 [07:51<09:16, 20.61s/it]

Progress:
	Train loss: 0.44971056194866404 | Train accuracy: 0.8710212111473083
	 Test loss: 0.5031337738037109 | Test accuracy: 0.8347780108451843

 48%|████▊     | 24/50 [08:11<08:53, 20.53s/it]

Progress:
	Train loss: 0.44508307646302614 | Train accuracy: 0.8749071359634399
	 Test loss: 0.4780619442462921 | Test accuracy: 0.8586478233337402

 50%|█████     | 25/50 [08:32<08:33, 20.54s/it]

Progress:
	Train loss: 0.43687203789458556 | Train accuracy: 0.8850791454315186
	 Test loss: 0.49023422598838806 | Test accuracy: 0.8467795848846436

 52%|█████▏    | 26/50 [08:53<08:15, 20.65s/it]

Progress:
	Train loss: 0.436979179873186 | Train accuracy: 0.8844505548477173
	 Test loss: 0.47813814878463745 | Test accuracy: 0.8606480956077576

 54%|█████▍    | 27/50 [09:13<07:54, 20.64s/it]

Progress:
	Train loss: 0.4318649462040733 | Train accuracy: 0.8888508081436157
	 Test loss: 0.4730978012084961 | Test accuracy: 0.8641152381896973

 56%|█████▌    | 28/50 [09:34<07:34, 20.65s/it]

Progress:
	Train loss: 0.4219725649146473 | Train accuracy: 0.9005657434463501
	 Test loss: 0.4726446270942688 | Test accuracy: 0.865982174873352

 58%|█████▊    | 29/50 [09:55<07:13, 20.65s/it]

Progress:
	Train loss: 0.422270251547589 | Train accuracy: 0.8980513215065002
	 Test loss: 0.4624493420124054 | Test accuracy: 0.8758501410484314

 60%|██████    | 30/50 [10:15<06:50, 20.55s/it]

Progress:
	Train loss: 0.41549651149441214 | Train accuracy: 0.9062803983688354
	 Test loss: 0.46395522356033325 | Test accuracy: 0.8746500015258789

 62%|██████▏   | 31/50 [10:36<06:30, 20.58s/it]

Progress:
	Train loss: 0.4160904691499822 | Train accuracy: 0.9050231575965881
	 Test loss: 0.4589424729347229 | Test accuracy: 0.8799840211868286

 64%|██████▍   | 32/50 [10:56<06:11, 20.63s/it]

Progress:
	Train loss: 0.40633199583081636 | Train accuracy: 0.915309488773346
	 Test loss: 0.4635559618473053 | Test accuracy: 0.8739832043647766

 66%|██████▌   | 33/50 [11:17<05:50, 20.63s/it]

Progress:
	Train loss: 0.40600840659702525 | Train accuracy: 0.9163952469825745
	 Test loss: 0.46159279346466064 | Test accuracy: 0.8786504864692688

 68%|██████▊   | 34/50 [11:37<05:28, 20.55s/it]

Progress:
	Train loss: 0.4016455227837843 | Train accuracy: 0.921595573425293
	 Test loss: 0.4534924626350403 | Test accuracy: 0.8839845657348633

 70%|███████   | 35/50 [11:58<05:07, 20.52s/it]

Progress:
	Train loss: 0.39987315763445463 | Train accuracy: 0.921938419342041
	 Test loss: 0.4545004367828369 | Test accuracy: 0.8823843598365784

 72%|███████▏  | 36/50 [12:18<04:47, 20.52s/it]

Progress:
	Train loss: 0.4013633333584842 | Train accuracy: 0.9220527410507202
	 Test loss: 0.46747952699661255 | Test accuracy: 0.868915855884552

 74%|███████▍  | 37/50 [12:39<04:26, 20.52s/it]

Progress:
	Train loss: 0.39652984720819134 | Train accuracy: 0.9246242642402649
	 Test loss: 0.4560621678829193 | Test accuracy: 0.8817175626754761

 76%|███████▌  | 38/50 [13:00<04:07, 20.60s/it]

Progress:
	Train loss: 0.39905737515758066 | Train accuracy: 0.9233099222183228
	 Test loss: 0.45262259244918823 | Test accuracy: 0.8855847716331482

 78%|███████▊  | 39/50 [13:20<03:46, 20.58s/it]

Progress:
	Train loss: 0.39189957520541024 | Train accuracy: 0.931538999080658
	 Test loss: 0.45166221261024475 | Test accuracy: 0.887451708316803

 80%|████████  | 40/50 [13:41<03:26, 20.61s/it]

Progress:
	Train loss: 0.38956483584993024 | Train accuracy: 0.931538999080658
	 Test loss: 0.4514645040035248 | Test accuracy: 0.8871849775314331

 82%|████████▏ | 41/50 [14:01<03:05, 20.60s/it]

Progress:
	Train loss: 0.3872317233506371 | Train accuracy: 0.9354820251464844
	 Test loss: 0.4631683826446533 | Test accuracy: 0.8741165995597839

 84%|████████▍ | 42/50 [14:22<02:43, 20.48s/it]

Progress:
	Train loss: 0.3897772028165705 | Train accuracy: 0.9319390058517456
	 Test loss: 0.443942666053772 | Test accuracy: 0.8938525319099426

 86%|████████▌ | 43/50 [14:41<02:21, 20.28s/it]

Progress:
	Train loss: 0.38308780771844525 | Train accuracy: 0.9393680095672607
	 Test loss: 0.4521996080875397 | Test accuracy: 0.8847846388816833

 88%|████████▊ | 44/50 [15:02<02:01, 20.33s/it]

Progress:
	Train loss: 0.3802995567812639 | Train accuracy: 0.9429110288619995
	 Test loss: 0.44500574469566345 | Test accuracy: 0.8930524587631226

 90%|█████████ | 45/50 [15:23<01:42, 20.47s/it]

Progress:
	Train loss: 0.37732454959084005 | Train accuracy: 0.9443968534469604
	 Test loss: 0.4466803967952728 | Test accuracy: 0.8927857279777527

 92%|█████████▏| 46/50 [15:43<01:21, 20.36s/it]

Progress:
	Train loss: 0.3744533105808146 | Train accuracy: 0.948111355304718
	 Test loss: 0.44999319314956665 | Test accuracy: 0.8903853893280029

 94%|█████████▍| 47/50 [16:02<00:59, 19.89s/it]

Progress:
	Train loss: 0.37116781665998344 | Train accuracy: 0.9505686163902283
	 Test loss: 0.44764310121536255 | Test accuracy: 0.8909187912940979

 96%|█████████▌| 48/50 [16:21<00:39, 19.76s/it]

Progress:
	Train loss: 0.3723902982823989 | Train accuracy: 0.9498257040977478
	 Test loss: 0.4484764635562897 | Test accuracy: 0.8909187912940979

 98%|█████████▊| 49/50 [16:41<00:19, 19.70s/it]

Progress:
	Train loss: 0.3687946989255793 | Train accuracy: 0.9549117088317871
	 Test loss: 0.43853330612182617 | Test accuracy: 0.9009201526641846

100%|██████████| 50/50 [17:00<00:00, 20.41s/it]

Progress:
	Train loss: 0.3678039689274395 | Train accuracy: 0.9564546942710876
	 Test loss: 0.4443427622318268 | Test accuracy: 0.8937191963195801




torch.save(model.state_dict(), "cnd_v2_89_acc.pt")