In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader, random_split
from torchvision import datasets, transforms
from torchvision.transforms import v2
import matplotlib.pyplot as plt
from sklearn.model_selection import train_test_split
import pandas as pd
import numpy as np
import pickle
import gc
import os
from tqdm import tqdm
from torchsummary import summary

import cnn_models

In [2]:
torch.manual_seed(42)

<torch._C.Generator at 0x7fd9bb541830>

In [3]:
device = 'cuda' if torch.cuda.is_available() else 'cpu'
device

'cuda'

In [4]:
data_transform = v2.Compose([
    v2.Resize((180, 180)),
    v2.ToImage(),
    v2.ToDtype(torch.float32, scale=True),
    v2.Normalize([0.5], [0.5]),
    v2.RandomInvert() 
])

In [5]:
data = datasets.ImageFolder("/home/pathetic/Documents/torch_masters/data/kagglecatsanddogs_5340/PetImages", transform = data_transform)

In [6]:
torch.manual_seed(42)
train_data, test_data = random_split(data, [.7, .3])
train_loader = DataLoader(dataset= train_data, num_workers= 12, batch_size= 256, shuffle= True)
test_loader  = DataLoader(dataset= test_data,  num_workers= 12, batch_size= 256, shuffle= True)

In [7]:
model_1 = cnn_models.CNN_casual(42).to(device)
summary(model_1, (3,180, 180))

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1         [-1, 32, 178, 178]             896
              ReLU-2         [-1, 32, 178, 178]               0
         MaxPool2d-3           [-1, 32, 89, 89]               0
            Conv2d-4           [-1, 64, 87, 87]          18,496
              ReLU-5           [-1, 64, 87, 87]               0
         MaxPool2d-6           [-1, 64, 43, 43]               0
            Conv2d-7          [-1, 128, 41, 41]          73,856
              ReLU-8          [-1, 128, 41, 41]               0
         MaxPool2d-9          [-1, 128, 20, 20]               0
           Conv2d-10          [-1, 256, 18, 18]         295,168
             ReLU-11          [-1, 256, 18, 18]               0
        MaxPool2d-12            [-1, 256, 9, 9]               0
           Conv2d-13            [-1, 256, 7, 7]         590,080
             ReLU-14            [-1, 25



In [8]:
model_2 = cnn_models.CNN_nin().to(device)
summary(model_2, (3, 180, 180))

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1         [-1, 32, 178, 178]             896
              ReLU-2         [-1, 32, 178, 178]               0
            Conv2d-3         [-1, 32, 178, 178]           1,056
              ReLU-4         [-1, 32, 178, 178]               0
            Conv2d-5         [-1, 32, 178, 178]           1,056
              ReLU-6         [-1, 32, 178, 178]               0
         MaxPool2d-7           [-1, 32, 89, 89]               0
            Conv2d-8           [-1, 64, 87, 87]          18,496
              ReLU-9           [-1, 64, 87, 87]               0
           Conv2d-10           [-1, 64, 87, 87]           4,160
             ReLU-11           [-1, 64, 87, 87]               0
           Conv2d-12           [-1, 64, 87, 87]           4,160
             ReLU-13           [-1, 64, 87, 87]               0
        MaxPool2d-14           [-1, 64,

In [9]:
model_3 = cnn_models.CNN_casual_norm(42).to(device)
summary(model_3, (3, 180, 180))

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1         [-1, 32, 178, 178]             896
       BatchNorm2d-2         [-1, 32, 178, 178]              64
              ReLU-3         [-1, 32, 178, 178]               0
         MaxPool2d-4           [-1, 32, 89, 89]               0
            Conv2d-5           [-1, 64, 87, 87]          18,496
       BatchNorm2d-6           [-1, 64, 87, 87]             128
              ReLU-7           [-1, 64, 87, 87]               0
         MaxPool2d-8           [-1, 64, 43, 43]               0
            Conv2d-9          [-1, 128, 41, 41]          73,856
      BatchNorm2d-10          [-1, 128, 41, 41]             256
             ReLU-11          [-1, 128, 41, 41]               0
        MaxPool2d-12          [-1, 128, 20, 20]               0
           Conv2d-13          [-1, 256, 18, 18]         295,168
      BatchNorm2d-14          [-1, 256,

In [10]:
def accuracy(fx, y):
    return (torch.argmax(fx, dim = 1) == y).float().sum()

In [11]:
def fit(model, train_data, test_data, loss_fn = nn.CrossEntropyLoss(), optimizer = optim.Adam, epochs = 500):
    optimizer = optimizer(model.parameters())
    loss_per_epoch = []
    acc_per_epoch = []
    for e in tqdm(range(epochs)):
        running_test_loss = running_train_loss = 0.
        test_acc = train_acc = 0.
        test_size = train_size = 0
        
        model.train()
        for i, (x_train, y_train) in enumerate(train_data):
            x_train = x_train.type(torch.float32).to(device)
            y_train = y_train.to(device)
            
            train_preds = model(x_train)
            train_loss = loss_fn(train_preds, y_train)
            running_train_loss += train_loss.item()
            train_acc += accuracy(train_preds, y_train)
            train_size += x_train.shape[0]
            optimizer.zero_grad()
            train_loss.backward()
            optimizer.step()
        
        model.eval()
        with torch.inference_mode():   
            for j, (x_test, y_test) in enumerate(test_data):
                x_test = x_test.type(torch.float32).to(device)
                y_test = y_test.to(device)
                test_preds = model(x_test)
                running_test_loss += loss_fn(test_preds, y_test)
                test_acc += accuracy(test_preds, y_test)
                test_size += x_test.shape[0]
            print(f"Progress:\n\tTrain loss: {running_train_loss / i} | Train accuracy: {train_acc / train_size}\n\t Test loss: {running_test_loss / j} | Test accuracy: {test_acc / test_size}", end = '\r')
            

In [12]:
import warnings
warnings.simplefilter("ignore")
gc.collect()
torch.cuda.empty_cache()
fit(model_3, train_loader, test_loader, epochs=50)

  2%|▏         | 1/50 [00:22<18:28, 22.63s/it]

Progress:
	Train loss: 0.6495964439476237 | Train accuracy: 0.6378650665283203
	 Test loss: 0.6859535574913025 | Test accuracy: 0.6095479726791382

  4%|▍         | 2/50 [00:45<18:03, 22.58s/it]

Progress:
	Train loss: 0.5853381183217553 | Train accuracy: 0.7279844880104065
	 Test loss: 0.6794784665107727 | Test accuracy: 0.6050140261650085

  6%|▌         | 3/50 [01:07<17:43, 22.62s/it]

Progress:
	Train loss: 0.5488510841832441 | Train accuracy: 0.7758729457855225
	 Test loss: 0.5680448412895203 | Test accuracy: 0.756367564201355

  8%|▊         | 4/50 [01:30<17:16, 22.52s/it]

Progress:
	Train loss: 0.5214958524002749 | Train accuracy: 0.8090748190879822
	 Test loss: 0.5906581878662109 | Test accuracy: 0.7162288427352905

 10%|█         | 5/50 [01:52<16:50, 22.45s/it]

Progress:
	Train loss: 0.5055847877965254 | Train accuracy: 0.8242756724357605
	 Test loss: 0.553654134273529 | Test accuracy: 0.7693026065826416

 12%|█▏        | 6/50 [02:14<16:26, 22.42s/it]

Progress:
	Train loss: 0.49379977408577413 | Train accuracy: 0.8347905874252319
	 Test loss: 0.5204214453697205 | Test accuracy: 0.8086411952972412

 14%|█▍        | 7/50 [02:37<16:04, 22.44s/it]

Progress:
	Train loss: 0.4754554546054672 | Train accuracy: 0.8578776121139526
	 Test loss: 0.48393845558166504 | Test accuracy: 0.8597146272659302

 16%|█▌        | 8/50 [02:59<15:39, 22.37s/it]

Progress:
	Train loss: 0.46296201865462694 | Train accuracy: 0.8685067892074585
	 Test loss: 0.5023924112319946 | Test accuracy: 0.8281104564666748

 18%|█▊        | 9/50 [03:21<15:13, 22.27s/it]

Progress:
	Train loss: 0.4532005519551389 | Train accuracy: 0.881993293762207
	 Test loss: 0.46808651089668274 | Test accuracy: 0.8703827261924744

 20%|██        | 10/50 [03:43<14:50, 22.27s/it]

Progress:
	Train loss: 0.44017638967317696 | Train accuracy: 0.8936510682106018
	 Test loss: 0.49904772639274597 | Test accuracy: 0.8255767822265625

 22%|██▏       | 11/50 [04:05<14:25, 22.20s/it]

Progress:
	Train loss: 0.4284979819374926 | Train accuracy: 0.9074233174324036
	 Test loss: 0.4575294852256775 | Test accuracy: 0.8774503469467163

 24%|██▍       | 12/50 [04:28<14:05, 22.24s/it]

Progress:
	Train loss: 0.42503141480333667 | Train accuracy: 0.9091376662254333
	 Test loss: 0.4544410705566406 | Test accuracy: 0.8815842270851135

 26%|██▌       | 13/50 [04:50<13:47, 22.36s/it]

Progress:
	Train loss: 0.41285736946498647 | Train accuracy: 0.922109842300415
	 Test loss: 0.4442048668861389 | Test accuracy: 0.8889185190200806

 28%|██▊       | 14/50 [05:13<13:29, 22.48s/it]

Progress:
	Train loss: 0.4075095171437544 | Train accuracy: 0.9266815781593323
	 Test loss: 0.4438576102256775 | Test accuracy: 0.895319402217865

 30%|███       | 15/50 [05:36<13:08, 22.53s/it]

Progress:
	Train loss: 0.39963305522413817 | Train accuracy: 0.9354820251464844
	 Test loss: 0.43310099840164185 | Test accuracy: 0.9038538932800293

 32%|███▏      | 16/50 [05:58<12:43, 22.44s/it]

Progress:
	Train loss: 0.392889859921792 | Train accuracy: 0.9435396790504456
	 Test loss: 0.633719801902771 | Test accuracy: 0.6667555570602417

 34%|███▍      | 17/50 [06:21<12:21, 22.48s/it]

Progress:
	Train loss: 0.3905082191614544 | Train accuracy: 0.9428539276123047
	 Test loss: 0.4524683952331543 | Test accuracy: 0.8771836757659912

 36%|███▌      | 18/50 [06:43<11:59, 22.47s/it]

Progress:
	Train loss: 0.3798688128590584 | Train accuracy: 0.9549688696861267
	 Test loss: 0.4510241150856018 | Test accuracy: 0.8790505528450012

 38%|███▊      | 19/50 [07:06<11:37, 22.49s/it]

Progress:
	Train loss: 0.37739868286777944 | Train accuracy: 0.9566832780838013
	 Test loss: 0.42741650342941284 | Test accuracy: 0.9039872288703918

 40%|████      | 20/50 [07:28<11:14, 22.47s/it]

Progress:
	Train loss: 0.37075023396926765 | Train accuracy: 0.9621692895889282
	 Test loss: 0.42378416657447815 | Test accuracy: 0.910254716873169

 42%|████▏     | 21/50 [07:51<10:52, 22.48s/it]

Progress:
	Train loss: 0.36738205701112747 | Train accuracy: 0.9657123684883118
	 Test loss: 0.42504414916038513 | Test accuracy: 0.9063875675201416

 44%|████▍     | 22/50 [08:13<10:30, 22.52s/it]

Progress:
	Train loss: 0.3621034096269047 | Train accuracy: 0.9707412123680115
	 Test loss: 0.43916118144989014 | Test accuracy: 0.8891852498054504

 46%|████▌     | 23/50 [08:36<10:07, 22.50s/it]

Progress:
	Train loss: 0.35962211209184985 | Train accuracy: 0.972455620765686
	 Test loss: 0.45188212394714355 | Test accuracy: 0.8766502737998962

 48%|████▊     | 24/50 [08:58<09:44, 22.48s/it]

Progress:
	Train loss: 0.35675456681672263 | Train accuracy: 0.9747414588928223
	 Test loss: 0.41255107522010803 | Test accuracy: 0.916255533695221

 50%|█████     | 25/50 [09:21<09:21, 22.47s/it]

Progress:
	Train loss: 0.34915757179260254 | Train accuracy: 0.9826847314834595
	 Test loss: 0.4211217164993286 | Test accuracy: 0.910254716873169

 52%|█████▏    | 26/50 [09:43<08:58, 22.45s/it]

Progress:
	Train loss: 0.34847571394022775 | Train accuracy: 0.9825133085250854
	 Test loss: 0.4175620973110199 | Test accuracy: 0.908654510974884

 54%|█████▍    | 27/50 [10:05<08:35, 22.39s/it]

Progress:
	Train loss: 0.34599582938586965 | Train accuracy: 0.9839991331100464
	 Test loss: 0.40767788887023926 | Test accuracy: 0.9197226762771606

 56%|█████▌    | 28/50 [10:27<08:10, 22.30s/it]

Progress:
	Train loss: 0.3436327563489185 | Train accuracy: 0.9862849712371826
	 Test loss: 0.4263468384742737 | Test accuracy: 0.8998533487319946

 58%|█████▊    | 29/50 [10:49<07:46, 22.23s/it]

Progress:
	Train loss: 0.33961379308910933 | Train accuracy: 0.9895423054695129
	 Test loss: 0.4193621575832367 | Test accuracy: 0.9087878465652466

 60%|██████    | 30/50 [11:12<07:25, 22.30s/it]

Progress:
	Train loss: 0.33978962459984946 | Train accuracy: 0.9889708161354065
	 Test loss: 0.41484493017196655 | Test accuracy: 0.9106547832489014

 62%|██████▏   | 31/50 [11:34<07:05, 22.38s/it]

Progress:
	Train loss: 0.3370157339993645 | Train accuracy: 0.9911995530128479
	 Test loss: 0.41628292202949524 | Test accuracy: 0.9125217199325562

 64%|██████▍   | 32/50 [11:57<06:44, 22.48s/it]

Progress:
	Train loss: 0.3351951802478117 | Train accuracy: 0.9923995733261108
	 Test loss: 0.40881577134132385 | Test accuracy: 0.9197226762771606

 66%|██████▌   | 33/50 [12:20<06:23, 22.54s/it]

Progress:
	Train loss: 0.3322206718080184 | Train accuracy: 0.9947425723075867
	 Test loss: 0.41908589005470276 | Test accuracy: 0.9067875742912292

 68%|██████▊   | 34/50 [12:42<06:00, 22.53s/it]

Progress:
	Train loss: 0.3330441060311654 | Train accuracy: 0.9934853911399841
	 Test loss: 0.47761985659599304 | Test accuracy: 0.8441125750541687

 70%|███████   | 35/50 [13:05<05:37, 22.47s/it]

Progress:
	Train loss: 0.33092449167195487 | Train accuracy: 0.9954283237457275
	 Test loss: 0.4006243348121643 | Test accuracy: 0.9281237721443176

 72%|███████▏  | 36/50 [13:27<05:14, 22.50s/it]

Progress:
	Train loss: 0.3306383341550827 | Train accuracy: 0.9949139952659607
	 Test loss: 0.4103868007659912 | Test accuracy: 0.9153220653533936

 74%|███████▍  | 37/50 [13:50<04:52, 22.52s/it]

Progress:
	Train loss: 0.32983363858040643 | Train accuracy: 0.9960569739341736
	 Test loss: 0.40295925736427307 | Test accuracy: 0.9238565564155579

 76%|███████▌  | 38/50 [14:12<04:30, 22.54s/it]

Progress:
	Train loss: 0.32956593658994227 | Train accuracy: 0.9960569739341736
	 Test loss: 0.39952346682548523 | Test accuracy: 0.9275903701782227

 78%|███████▊  | 39/50 [14:35<04:08, 22.61s/it]

Progress:
	Train loss: 0.3280272851972019 | Train accuracy: 0.9967427253723145
	 Test loss: 0.4183642268180847 | Test accuracy: 0.9071876406669617

 80%|████████  | 40/50 [14:58<03:45, 22.60s/it]

Progress:
	Train loss: 0.3280867531019099 | Train accuracy: 0.9962283968925476
	 Test loss: 0.39965999126434326 | Test accuracy: 0.9279904365539551

 82%|████████▏ | 41/50 [15:20<03:22, 22.47s/it]

Progress:
	Train loss: 0.3261181230054182 | Train accuracy: 0.9975998997688293
	 Test loss: 0.4010853171348572 | Test accuracy: 0.9257234334945679

 84%|████████▍ | 42/50 [15:41<02:57, 22.15s/it]

Progress:
	Train loss: 0.326139736263191 | Train accuracy: 0.9974855780601501
	 Test loss: 0.407340943813324 | Test accuracy: 0.9181224703788757

 86%|████████▌ | 43/50 [16:03<02:33, 21.93s/it]

Progress:
	Train loss: 0.32631254327647824 | Train accuracy: 0.9971427321434021
	 Test loss: 0.3996972441673279 | Test accuracy: 0.9246566295623779

 88%|████████▊ | 44/50 [16:24<02:10, 21.77s/it]

Progress:
	Train loss: 0.3265145999543807 | Train accuracy: 0.9969713091850281
	 Test loss: 0.41349923610687256 | Test accuracy: 0.9119883179664612

 90%|█████████ | 45/50 [16:45<01:48, 21.66s/it]

Progress:
	Train loss: 0.325215905028231 | Train accuracy: 0.9977713227272034
	 Test loss: 0.39800992608070374 | Test accuracy: 0.9286571741104126

 92%|█████████▏| 46/50 [17:07<01:26, 21.60s/it]

Progress:
	Train loss: 0.3258204990450074 | Train accuracy: 0.9973141551017761
	 Test loss: 0.3998703360557556 | Test accuracy: 0.923589825630188

 94%|█████████▍| 47/50 [17:28<01:04, 21.55s/it]

Progress:
	Train loss: 0.3267129181939013 | Train accuracy: 0.9969141483306885
	 Test loss: 0.40143057703971863 | Test accuracy: 0.9269236326217651

 96%|█████████▌| 48/50 [17:50<00:43, 21.52s/it]

Progress:
	Train loss: 0.3257584584986462 | Train accuracy: 0.9970855712890625
	 Test loss: 0.40204235911369324 | Test accuracy: 0.9229230880737305

 98%|█████████▊| 49/50 [18:11<00:21, 21.47s/it]

Progress:
	Train loss: 0.3281830987509559 | Train accuracy: 0.9948568940162659
	 Test loss: 0.4025285243988037 | Test accuracy: 0.923589825630188

100%|██████████| 50/50 [18:32<00:00, 22.26s/it]

Progress:
	Train loss: 0.3270672314307269 | Train accuracy: 0.9957712292671204
	 Test loss: 0.4015301465988159 | Test accuracy: 0.9242566227912903




In [13]:
torch.save(model_2.state_dict(), "cnd_normed_92_acc.pt")