# Pytorch Implementation

In [1]:
import torch
import torch.nn as nn
from torchvision import datasets, transforms

In [2]:
from modules.CNN import CNN

In [3]:
from tqdm import tqdm

## Data

In [4]:
data_1_4 = fr"\Data\1_4"
data_5_8 = fr"\Data\5_8"
data_9_12= fr"\Data\9_12"

data_1_12= fr"Data\Final Testing Dataset\Final Testing Dataset"

In [5]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [6]:
transform = transforms.Compose([
    transforms.ToTensor(),
    # transforms.Resize((40, 40)).to(device),
    transforms.RandomRotation(45).to(device)
])

## Model

### 1_12

In [7]:
root_dir = r'Data\Final Testing Dataset\Final Testing Dataset'
dataset = datasets.ImageFolder(root=root_dir, transform=transform)

In [8]:
model = CNN(12).to(device)

In [9]:
val_size = int(0.2 * len(dataset))
test_size = int(0.2 * len(dataset))
train_size = len(dataset) - val_size - test_size

# Split the dataset into training, validation, and test sets
train_dataset, val_dataset, test_dataset = torch.utils.data.random_split(dataset, [train_size, val_size, test_size])

In [10]:

train_dataloader = torch.utils.data.DataLoader(train_dataset, batch_size=32, shuffle=True, num_workers=4)
val_dataloader = torch.utils.data.DataLoader(val_dataset, batch_size=32, shuffle=False, num_workers=4)
test_dataloader = torch.utils.data.DataLoader(test_dataset, batch_size=32, shuffle=False, num_workers=4)

In [11]:
loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

In [12]:
length = len(train_dataloader)
for epoch in range(50):  # loop over the dataset multiple times
    print("Epoch: ", epoch+1)
    pbar = tqdm(train_dataloader)
    running_loss = 0.0
    correct = 0
    total_seen = 0
    for inputs, labels in pbar:
        inputs, labels = inputs.to(device), labels.to(device)
        # total_seen += labels.size(0)

        optimizer.zero_grad()

        outputs = model(inputs)
        loss = loss_fn(outputs, labels)
        loss.backward()
        optimizer.step()

        running_loss += loss.item()
        # print(labels.shape, outputs)
        _, predicted = torch.max(outputs.data, 1)
        # print(predicted.size(0))
        total_seen+=predicted.size(0)
        correct += (predicted == labels).sum().item()
        pbar.set_description(f"Running Accuracy: { correct/total_seen},  \t Batch Loss:  {loss}")

    # Testing the network
    correct = 0
    total = 0
    total_loss=0
    with torch.no_grad():
        for (images, labels) in val_dataloader:
            images, labels = images.to(device), labels.to(device)
            outputs = model(images)
            val_loss = nn.CrossEntropyLoss()(outputs, labels)
            total_loss += loss.item()
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
        print("Accuracy: ", correct/total, "\tLoss: ", total_loss/len(val_dataloader))

Epoch:  1


Running Accuracy: 0.23271313524021106,  	 Batch Loss:  1.794292688369751: 100%|██████████| 113/113 [00:09<00:00, 11.94it/s]


Accuracy:  0.3416666666666667 	Loss:  1.794292688369751
Epoch:  2


Running Accuracy: 0.35018050541516244,  	 Batch Loss:  1.9072198867797852: 100%|██████████| 113/113 [00:07<00:00, 15.30it/s]


Accuracy:  0.38333333333333336 	Loss:  1.9072198867797852
Epoch:  3


Running Accuracy: 0.40238822549291864,  	 Batch Loss:  1.4016845226287842: 100%|██████████| 113/113 [00:07<00:00, 15.42it/s]


Accuracy:  0.3625 	Loss:  1.4016845226287842
Epoch:  4


Running Accuracy: 0.47875590113857264,  	 Batch Loss:  1.2770018577575684: 100%|██████████| 113/113 [00:07<00:00, 15.43it/s]


Accuracy:  0.49083333333333334 	Loss:  1.2770018577575684
Epoch:  5


Running Accuracy: 0.542349347403499,  	 Batch Loss:  1.1132389307022095: 100%|██████████| 113/113 [00:07<00:00, 15.32it/s] 


Accuracy:  0.5241666666666667 	Loss:  1.1132389307022095
Epoch:  6


Running Accuracy: 0.5417939461260761,  	 Batch Loss:  1.3848941326141357: 100%|██████████| 113/113 [00:07<00:00, 15.46it/s]


Accuracy:  0.565 	Loss:  1.3848941326141357
Epoch:  7


Running Accuracy: 0.6064981949458483,  	 Batch Loss:  0.9637665748596191: 100%|██████████| 113/113 [00:07<00:00, 15.60it/s]


Accuracy:  0.5266666666666666 	Loss:  0.9637665748596191
Epoch:  8


Running Accuracy: 0.5923354623715634,  	 Batch Loss:  0.8349295258522034: 100%|██████████| 113/113 [00:07<00:00, 15.50it/s]


Accuracy:  0.5808333333333333 	Loss:  0.8349295258522034
Epoch:  9


Running Accuracy: 0.6264926409330741,  	 Batch Loss:  0.8002030849456787: 100%|██████████| 113/113 [00:07<00:00, 15.46it/s]


Accuracy:  0.6133333333333333 	Loss:  0.8002030849456787
Epoch:  10


Running Accuracy: 0.6273257428492085,  	 Batch Loss:  0.9647778868675232: 100%|██████████| 113/113 [00:07<00:00, 15.71it/s]


Accuracy:  0.6008333333333333 	Loss:  0.9647778868675232
Epoch:  11


Running Accuracy: 0.6525965009719522,  	 Batch Loss:  0.6207951903343201: 100%|██████████| 113/113 [00:07<00:00, 15.69it/s]


Accuracy:  0.5558333333333333 	Loss:  0.6207951903343201
Epoch:  12


Running Accuracy: 0.6600944182171619,  	 Batch Loss:  0.8127725720405579: 100%|██████████| 113/113 [00:07<00:00, 15.72it/s]


Accuracy:  0.63 	Loss:  0.8127725720405579
Epoch:  13


Running Accuracy: 0.6737017495140238,  	 Batch Loss:  0.6031840443611145: 100%|██████████| 113/113 [00:07<00:00, 15.69it/s]


Accuracy:  0.6375 	Loss:  0.6031840443611145
Epoch:  14


Running Accuracy: 0.6692585392946404,  	 Batch Loss:  1.0582355260849: 100%|██████████| 113/113 [00:07<00:00, 15.98it/s]  


Accuracy:  0.6508333333333334 	Loss:  1.0582355260849
Epoch:  15


Running Accuracy: 0.6842543737850597,  	 Batch Loss:  0.6255459785461426: 100%|██████████| 113/113 [00:07<00:00, 15.67it/s]


Accuracy:  0.6116666666666667 	Loss:  0.6255459785461426
Epoch:  16


Running Accuracy: 0.690919189114135,  	 Batch Loss:  0.8226152062416077: 100%|██████████| 113/113 [00:07<00:00, 15.70it/s] 


Accuracy:  0.6391666666666667 	Loss:  0.8226152062416077
Epoch:  17


Running Accuracy: 0.7023049153013052,  	 Batch Loss:  0.4828568696975708: 100%|██████████| 113/113 [00:07<00:00, 15.64it/s]


Accuracy:  0.6316666666666667 	Loss:  0.4828568696975708
Epoch:  18


Running Accuracy: 0.68619827825604,  	 Batch Loss:  0.7510051727294922: 100%|██████████| 113/113 [00:07<00:00, 15.82it/s]  


Accuracy:  0.6408333333333334 	Loss:  0.7510051727294922
Epoch:  19


Running Accuracy: 0.7081366287142461,  	 Batch Loss:  0.771457850933075: 100%|██████████| 113/113 [00:07<00:00, 15.53it/s] 


Accuracy:  0.6675 	Loss:  0.771457850933075
Epoch:  20


Running Accuracy: 0.7136906414884754,  	 Batch Loss:  0.5629341006278992: 100%|██████████| 113/113 [00:07<00:00, 15.60it/s]


Accuracy:  0.6775 	Loss:  0.5629341006278992
Epoch:  21


Running Accuracy: 0.7136906414884754,  	 Batch Loss:  0.5644461512565613: 100%|██████████| 113/113 [00:07<00:00, 15.84it/s] 


Accuracy:  0.6658333333333334 	Loss:  0.5644461512565613
Epoch:  22


Running Accuracy: 0.7250763676756456,  	 Batch Loss:  0.763196587562561: 100%|██████████| 113/113 [00:07<00:00, 15.63it/s] 


Accuracy:  0.6466666666666666 	Loss:  0.763196587562561
Epoch:  23


Running Accuracy: 0.712857539572341,  	 Batch Loss:  0.912513256072998: 100%|██████████| 113/113 [00:07<00:00, 15.83it/s]  


Accuracy:  0.6125 	Loss:  0.912513256072998
Epoch:  24


Running Accuracy: 0.7200777561788392,  	 Batch Loss:  0.9176917672157288: 100%|██████████| 113/113 [00:07<00:00, 15.90it/s] 


Accuracy:  0.6816666666666666 	Loss:  0.9176917672157288
Epoch:  25


Running Accuracy: 0.7186892529852819,  	 Batch Loss:  0.7490048408508301: 100%|██████████| 113/113 [00:07<00:00, 15.67it/s]


Accuracy:  0.675 	Loss:  0.7490048408508301
Epoch:  26


Running Accuracy: 0.7356289919466815,  	 Batch Loss:  0.827994167804718: 100%|██████████| 113/113 [00:07<00:00, 15.50it/s] 


Accuracy:  0.6791666666666667 	Loss:  0.827994167804718
Epoch:  27


Running Accuracy: 0.7322965842821438,  	 Batch Loss:  0.7915317416191101: 100%|██████████| 113/113 [00:07<00:00, 15.61it/s]


Accuracy:  0.6866666666666666 	Loss:  0.7915317416191101
Epoch:  28


Running Accuracy: 0.7409053040821993,  	 Batch Loss:  0.820891797542572: 100%|██████████| 113/113 [00:07<00:00, 15.52it/s] 


Accuracy:  0.6475 	Loss:  0.820891797542572
Epoch:  29


Running Accuracy: 0.7375728964176618,  	 Batch Loss:  0.5408692359924316: 100%|██████████| 113/113 [00:07<00:00, 15.48it/s]


Accuracy:  0.6908333333333333 	Loss:  0.5408692359924316
Epoch:  30


Running Accuracy: 0.7386836989725076,  	 Batch Loss:  0.7079589366912842: 100%|██████████| 113/113 [00:07<00:00, 15.32it/s]


Accuracy:  0.6616666666666666 	Loss:  0.7079589366912842
Epoch:  31


Running Accuracy: 0.7439600111080256,  	 Batch Loss:  0.8444207310676575: 100%|██████████| 113/113 [00:07<00:00, 15.50it/s]


Accuracy:  0.6916666666666667 	Loss:  0.8444207310676575
Epoch:  32


Running Accuracy: 0.744237711746737,  	 Batch Loss:  0.9034938812255859: 100%|██████████| 113/113 [00:07<00:00, 15.41it/s] 


Accuracy:  0.7033333333333334 	Loss:  0.9034938812255859
Epoch:  33


Running Accuracy: 0.7517356289919467,  	 Batch Loss:  0.5437516570091248: 100%|██████████| 113/113 [00:07<00:00, 15.75it/s] 


Accuracy:  0.7 	Loss:  0.5437516570091248
Epoch:  34


Running Accuracy: 0.7503471257983894,  	 Batch Loss:  0.4798906445503235: 100%|██████████| 113/113 [00:07<00:00, 15.24it/s]


Accuracy:  0.7083333333333334 	Loss:  0.4798906445503235
Epoch:  35


Running Accuracy: 0.7642321577339628,  	 Batch Loss:  0.5668668746948242: 100%|██████████| 113/113 [00:07<00:00, 15.61it/s]


Accuracy:  0.6891666666666667 	Loss:  0.5668668746948242
Epoch:  36


Running Accuracy: 0.7572896417661761,  	 Batch Loss:  0.762477457523346: 100%|██████████| 113/113 [00:07<00:00, 15.73it/s] 


Accuracy:  0.7066666666666667 	Loss:  0.762477457523346
Epoch:  37


Running Accuracy: 0.7572896417661761,  	 Batch Loss:  0.8642481565475464: 100%|██████████| 113/113 [00:07<00:00, 15.06it/s] 


Accuracy:  0.66 	Loss:  0.8642481565475464
Epoch:  38


Running Accuracy: 0.7497917245209664,  	 Batch Loss:  0.48060479760169983: 100%|██████████| 113/113 [00:07<00:00, 15.79it/s]


Accuracy:  0.6766666666666666 	Loss:  0.48060479760169983
Epoch:  39


Running Accuracy: 0.7575673424048875,  	 Batch Loss:  0.25332507491111755: 100%|██████████| 113/113 [00:07<00:00, 15.82it/s]


Accuracy:  0.6916666666666667 	Loss:  0.25332507491111755
Epoch:  40


Running Accuracy: 0.7745070813662871,  	 Batch Loss:  0.5016249418258667: 100%|██████████| 113/113 [00:07<00:00, 15.96it/s]


Accuracy:  0.6775 	Loss:  0.5016249418258667
Epoch:  41


Running Accuracy: 0.7692307692307693,  	 Batch Loss:  0.5920419692993164: 100%|██████████| 113/113 [00:07<00:00, 15.92it/s]


Accuracy:  0.6933333333333334 	Loss:  0.5920419692993164
Epoch:  42


Running Accuracy: 0.7711746737017495,  	 Batch Loss:  0.558148205280304: 100%|██████████| 113/113 [00:07<00:00, 15.75it/s] 


Accuracy:  0.6275 	Loss:  0.558148205280304
Epoch:  43


Running Accuracy: 0.7706192724243266,  	 Batch Loss:  0.4317047894001007: 100%|██████████| 113/113 [00:07<00:00, 15.78it/s]


Accuracy:  0.6875 	Loss:  0.4317047894001007
Epoch:  44


Running Accuracy: 0.7908914190502638,  	 Batch Loss:  0.532651424407959: 100%|██████████| 113/113 [00:07<00:00, 15.89it/s] 


Accuracy:  0.7333333333333333 	Loss:  0.532651424407959
Epoch:  45


Running Accuracy: 0.7925576228825326,  	 Batch Loss:  0.5363736152648926: 100%|██████████| 113/113 [00:07<00:00, 15.79it/s]


Accuracy:  0.6875 	Loss:  0.5363736152648926
Epoch:  46


Running Accuracy: 0.7717300749791725,  	 Batch Loss:  0.48017334938049316: 100%|██████████| 113/113 [00:07<00:00, 15.96it/s]


Accuracy:  0.685 	Loss:  0.48017334938049316
Epoch:  47


Running Accuracy: 0.7947792279922243,  	 Batch Loss:  0.5164469480514526: 100%|██████████| 113/113 [00:07<00:00, 16.10it/s]


Accuracy:  0.7175 	Loss:  0.5164469480514526
Epoch:  48


Running Accuracy: 0.7831158011663427,  	 Batch Loss:  0.5234214067459106: 100%|██████████| 113/113 [00:07<00:00, 15.64it/s] 


Accuracy:  0.7075 	Loss:  0.5234214067459106
Epoch:  49


Running Accuracy: 0.7792279922243821,  	 Batch Loss:  0.7390831112861633: 100%|██████████| 113/113 [00:07<00:00, 15.76it/s]


Accuracy:  0.7083333333333334 	Loss:  0.7390831112861633
Epoch:  50


Running Accuracy: 0.7817272979727853,  	 Batch Loss:  0.6144466996192932: 100%|██████████| 113/113 [00:07<00:00, 15.90it/s]


Accuracy:  0.7075 	Loss:  0.6144466996192932


In [13]:
torch.save(model, '1_12_bats.pth')

### 1_4

In [14]:
root_dir = r'Data\Top_level\1_4'
dataset = datasets.ImageFolder(root=root_dir, transform=transform)
model_1_4 = CNN(4).to(device)
val_size = int(0.2 * len(dataset))
test_size = int(0.2 * len(dataset))
train_size = len(dataset) - val_size - test_size

# Split the dataset into training, validation, and test sets
train_dataset, val_dataset, test_dataset = torch.utils.data.random_split(dataset, [train_size, val_size, test_size])
train_dataloader = torch.utils.data.DataLoader(train_dataset, batch_size=32, shuffle=True, num_workers=4)
val_dataloader = torch.utils.data.DataLoader(val_dataset, batch_size=32, shuffle=False, num_workers=4)
test_dataloader = torch.utils.data.DataLoader(test_dataset, batch_size=32, shuffle=False, num_workers=4)
loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model_1_4.parameters(), lr=0.0001)

length = len(train_dataloader)
for epoch in range(50):  # loop over the dataset multiple times
    print("Epoch: ", epoch+1)
    pbar = tqdm(train_dataloader)
    running_loss = 0.0
    correct = 0
    total_seen = 0
    for inputs, labels in pbar:
        inputs, labels = inputs.to(device), labels.to(device)
        # total_seen += labels.size(0)

        optimizer.zero_grad()

        outputs = model_1_4(inputs)
        loss = loss_fn(outputs, labels)
        loss.backward()
        optimizer.step()

        running_loss += loss.item()
        # print(labels.shape, outputs)
        _, predicted = torch.max(outputs.data, 1)
        # print(predicted.size(0))
        total_seen+=predicted.size(0)
        correct += (predicted == labels).sum().item()
        pbar.set_description(f"Running Accuracy: { correct/total_seen},  \t Batch Loss:  {loss}")

    # Testing the network
    correct = 0
    total = 0
    total_loss=0
    with torch.no_grad():
        for (images, labels) in val_dataloader:
            images, labels = images.to(device), labels.to(device)
            outputs = model_1_4(images)
            val_loss = nn.CrossEntropyLoss()(outputs, labels)
            total_loss += loss.item()
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
        print("Accuracy: ", correct/total, "\tLoss: ", total_loss/len(val_dataloader))

Epoch:  1


Running Accuracy: 0.3064113238967527,  	 Batch Loss:  1.3490934371948242: 100%|██████████| 38/38 [00:06<00:00,  5.58it/s] 


Accuracy:  0.2625 	Loss:  1.3490934371948242
Epoch:  2


Running Accuracy: 0.4496253122398002,  	 Batch Loss:  1.0003957748413086: 100%|██████████| 38/38 [00:06<00:00,  5.54it/s] 


Accuracy:  0.55 	Loss:  1.0003957748413086
Epoch:  3


Running Accuracy: 0.5736885928393006,  	 Batch Loss:  0.6543607115745544: 100%|██████████| 38/38 [00:06<00:00,  5.62it/s]


Accuracy:  0.6075 	Loss:  0.6543607115745544
Epoch:  4


Running Accuracy: 0.611157368859284,  	 Batch Loss:  0.744815468788147: 100%|██████████| 38/38 [00:06<00:00,  5.54it/s]  


Accuracy:  0.6025 	Loss:  0.744815468788147
Epoch:  5


Running Accuracy: 0.6169858451290591,  	 Batch Loss:  0.6549387574195862: 100%|██████████| 38/38 [00:07<00:00,  5.43it/s]


Accuracy:  0.63 	Loss:  0.6549387574195862
Epoch:  6


Running Accuracy: 0.6394671107410491,  	 Batch Loss:  0.6839158535003662: 100%|██████████| 38/38 [00:06<00:00,  5.54it/s]


Accuracy:  0.6575 	Loss:  0.6839158535003662
Epoch:  7


Running Accuracy: 0.6477935054121565,  	 Batch Loss:  0.4989982545375824: 100%|██████████| 38/38 [00:06<00:00,  5.52it/s]


Accuracy:  0.6275 	Loss:  0.4989982545375824
Epoch:  8


Running Accuracy: 0.671940049958368,  	 Batch Loss:  0.5839374661445618: 100%|██████████| 38/38 [00:06<00:00,  5.51it/s] 


Accuracy:  0.6625 	Loss:  0.5839374661445618
Epoch:  9


Running Accuracy: 0.6794338051623647,  	 Batch Loss:  0.44876372814178467: 100%|██████████| 38/38 [00:06<00:00,  5.71it/s]


Accuracy:  0.685 	Loss:  0.44876372814178467
Epoch:  10


Running Accuracy: 0.6935886761032473,  	 Batch Loss:  0.5898594260215759: 100%|██████████| 38/38 [00:06<00:00,  5.82it/s]


Accuracy:  0.645 	Loss:  0.5898594260215759
Epoch:  11


Running Accuracy: 0.7035803497085762,  	 Batch Loss:  0.5538922548294067: 100%|██████████| 38/38 [00:06<00:00,  5.82it/s]


Accuracy:  0.7175 	Loss:  0.5538922548294067
Epoch:  12


Running Accuracy: 0.7243963363863447,  	 Batch Loss:  0.45531216263771057: 100%|██████████| 38/38 [00:06<00:00,  5.84it/s]


Accuracy:  0.7075 	Loss:  0.45531216263771057
Epoch:  13


Running Accuracy: 0.7218984179850125,  	 Batch Loss:  0.6436710953712463: 100%|██████████| 38/38 [00:06<00:00,  5.69it/s]


Accuracy:  0.7075 	Loss:  0.6436710953712463
Epoch:  14


Running Accuracy: 0.7368859283930058,  	 Batch Loss:  0.536884605884552: 100%|██████████| 38/38 [00:06<00:00,  5.51it/s] 


Accuracy:  0.75 	Loss:  0.536884605884552
Epoch:  15


Running Accuracy: 0.7468776019983348,  	 Batch Loss:  0.6328352093696594: 100%|██████████| 38/38 [00:06<00:00,  5.66it/s] 


Accuracy:  0.775 	Loss:  0.6328352093696594
Epoch:  16


Running Accuracy: 0.7243963363863447,  	 Batch Loss:  0.7614548206329346: 100%|██████████| 38/38 [00:06<00:00,  5.77it/s]


Accuracy:  0.5875 	Loss:  0.7614548206329346
Epoch:  17


Running Accuracy: 0.7185678601165695,  	 Batch Loss:  0.7888888120651245: 100%|██████████| 38/38 [00:06<00:00,  5.73it/s] 


Accuracy:  0.685 	Loss:  0.7888888120651245
Epoch:  18


Running Accuracy: 0.7810158201498751,  	 Batch Loss:  0.34661003947257996: 100%|██████████| 38/38 [00:06<00:00,  5.72it/s]


Accuracy:  0.715 	Loss:  0.34661003947257996
Epoch:  19


Running Accuracy: 0.7743547044129891,  	 Batch Loss:  0.4181738495826721: 100%|██████████| 38/38 [00:06<00:00,  5.69it/s] 


Accuracy:  0.72 	Loss:  0.4181738495826721
Epoch:  20


Running Accuracy: 0.7768526228143214,  	 Batch Loss:  0.4179763197898865: 100%|██████████| 38/38 [00:06<00:00,  5.72it/s] 


Accuracy:  0.7675 	Loss:  0.4179763197898865
Epoch:  21


Running Accuracy: 0.7851790174854288,  	 Batch Loss:  0.5774780511856079: 100%|██████████| 38/38 [00:06<00:00,  5.66it/s] 


Accuracy:  0.7725 	Loss:  0.5774780511856079
Epoch:  22


Running Accuracy: 0.8101582014987511,  	 Batch Loss:  0.32969143986701965: 100%|██████████| 38/38 [00:06<00:00,  5.72it/s]


Accuracy:  0.79 	Loss:  0.32969143986701965
Epoch:  23


Running Accuracy: 0.7776852622814321,  	 Batch Loss:  0.5244135856628418: 100%|██████████| 38/38 [00:06<00:00,  5.73it/s] 


Accuracy:  0.7375 	Loss:  0.5244135856628418
Epoch:  24


Running Accuracy: 0.79766860949209,  	 Batch Loss:  0.4962545335292816: 100%|██████████| 38/38 [00:06<00:00,  5.66it/s]   


Accuracy:  0.7875 	Loss:  0.4962545335292816
Epoch:  25


Running Accuracy: 0.8118234804329725,  	 Batch Loss:  0.5244314670562744: 100%|██████████| 38/38 [00:06<00:00,  5.70it/s] 


Accuracy:  0.7825 	Loss:  0.5244314670562744
Epoch:  26


Running Accuracy: 0.8076602830974188,  	 Batch Loss:  0.5250623822212219: 100%|██████████| 38/38 [00:06<00:00,  5.67it/s] 


Accuracy:  0.6975 	Loss:  0.5250623822212219
Epoch:  27


Running Accuracy: 0.791007493755204,  	 Batch Loss:  0.38533905148506165: 100%|██████████| 38/38 [00:06<00:00,  5.72it/s] 


Accuracy:  0.7975 	Loss:  0.38533905148506165
Epoch:  28


Running Accuracy: 0.8226477935054122,  	 Batch Loss:  0.4914810061454773: 100%|██████████| 38/38 [00:06<00:00,  5.74it/s] 


Accuracy:  0.755 	Loss:  0.4914810061454773
Epoch:  29


Running Accuracy: 0.8043297252289758,  	 Batch Loss:  0.5156604647636414: 100%|██████████| 38/38 [00:06<00:00,  5.72it/s] 


Accuracy:  0.8175 	Loss:  0.5156604647636414
Epoch:  30


Running Accuracy: 0.8168193172356369,  	 Batch Loss:  0.4173518121242523: 100%|██████████| 38/38 [00:06<00:00,  5.82it/s] 


Accuracy:  0.76 	Loss:  0.4173518121242523
Epoch:  31


Running Accuracy: 0.8284762697751873,  	 Batch Loss:  0.4741628170013428: 100%|██████████| 38/38 [00:06<00:00,  5.73it/s] 


Accuracy:  0.805 	Loss:  0.4741628170013428
Epoch:  32


Running Accuracy: 0.8259783513738551,  	 Batch Loss:  0.2316545844078064: 100%|██████████| 38/38 [00:06<00:00,  5.81it/s] 


Accuracy:  0.83 	Loss:  0.2316545844078064
Epoch:  33


Running Accuracy: 0.829308909242298,  	 Batch Loss:  0.39412426948547363: 100%|██████████| 38/38 [00:06<00:00,  5.85it/s] 


Accuracy:  0.82 	Loss:  0.39412426948547363
Epoch:  34


Running Accuracy: 0.8151540383014155,  	 Batch Loss:  0.16502290964126587: 100%|██████████| 38/38 [00:06<00:00,  5.73it/s]


Accuracy:  0.7775 	Loss:  0.16502290964126587
Epoch:  35


Running Accuracy: 0.8276436303080766,  	 Batch Loss:  0.2988996207714081: 100%|██████████| 38/38 [00:06<00:00,  5.68it/s] 


Accuracy:  0.8175 	Loss:  0.2988996207714081
Epoch:  36


Running Accuracy: 0.8259783513738551,  	 Batch Loss:  0.4904021620750427: 100%|██████████| 38/38 [00:06<00:00,  5.66it/s] 


Accuracy:  0.8225 	Loss:  0.4904021620750427
Epoch:  37


Running Accuracy: 0.832639467110741,  	 Batch Loss:  0.3913227915763855: 100%|██████████| 38/38 [00:06<00:00,  5.85it/s]  


Accuracy:  0.8525 	Loss:  0.3913227915763855
Epoch:  38


Running Accuracy: 0.8476269775187344,  	 Batch Loss:  0.41258999705314636: 100%|██████████| 38/38 [00:06<00:00,  5.71it/s]


Accuracy:  0.785 	Loss:  0.41258999705314636
Epoch:  39


Running Accuracy: 0.8218151540383014,  	 Batch Loss:  0.6644977331161499: 100%|██████████| 38/38 [00:06<00:00,  5.72it/s] 


Accuracy:  0.84 	Loss:  0.6644977331161499
Epoch:  40


Running Accuracy: 0.8467943380516236,  	 Batch Loss:  0.5095663070678711: 100%|██████████| 38/38 [00:06<00:00,  5.83it/s] 


Accuracy:  0.825 	Loss:  0.5095663070678711
Epoch:  41


Running Accuracy: 0.8501248959200666,  	 Batch Loss:  0.12554439902305603: 100%|██████████| 38/38 [00:06<00:00,  5.51it/s]


Accuracy:  0.8425 	Loss:  0.12554439902305603
Epoch:  42


Running Accuracy: 0.8434637801831807,  	 Batch Loss:  0.4028525948524475: 100%|██████████| 38/38 [00:06<00:00,  5.65it/s] 


Accuracy:  0.8325 	Loss:  0.4028525948524475
Epoch:  43


Running Accuracy: 0.8484596169858452,  	 Batch Loss:  0.29306459426879883: 100%|██████████| 38/38 [00:06<00:00,  5.67it/s]


Accuracy:  0.8225 	Loss:  0.29306459426879883
Epoch:  44


Running Accuracy: 0.8467943380516236,  	 Batch Loss:  0.1738164722919464: 100%|██████████| 38/38 [00:07<00:00,  5.28it/s] 


Accuracy:  0.82 	Loss:  0.1738164722919464
Epoch:  45


Running Accuracy: 0.8501248959200666,  	 Batch Loss:  0.3759496212005615: 100%|██████████| 38/38 [00:06<00:00,  5.72it/s] 


Accuracy:  0.8425 	Loss:  0.3759496212005615
Epoch:  46


Running Accuracy: 0.8484596169858452,  	 Batch Loss:  0.2843346893787384: 100%|██████████| 38/38 [00:06<00:00,  5.78it/s] 


Accuracy:  0.7925 	Loss:  0.2843346893787384
Epoch:  47


Running Accuracy: 0.8467943380516236,  	 Batch Loss:  0.78061443567276: 100%|██████████| 38/38 [00:06<00:00,  5.68it/s]   


Accuracy:  0.835 	Loss:  0.78061443567276
Epoch:  48


Running Accuracy: 0.8426311407160699,  	 Batch Loss:  0.2580624520778656: 100%|██████████| 38/38 [00:06<00:00,  5.66it/s] 


Accuracy:  0.8475 	Loss:  0.2580624520778656
Epoch:  49


Running Accuracy: 0.8343047460449625,  	 Batch Loss:  0.1781231164932251: 100%|██████████| 38/38 [00:06<00:00,  5.78it/s] 


Accuracy:  0.8425 	Loss:  0.1781231164932251
Epoch:  50


Running Accuracy: 0.8559533721898418,  	 Batch Loss:  0.3188185691833496: 100%|██████████| 38/38 [00:06<00:00,  5.52it/s] 


Accuracy:  0.8075 	Loss:  0.3188185691833496


### 5_8

In [15]:
root_dir = r'Data\Top_level\5_8'
dataset = datasets.ImageFolder(root=root_dir, transform=transform)
model_5_8 = CNN(4).to(device)
val_size = int(0.2 * len(dataset))
test_size = int(0.2 * len(dataset))
train_size = len(dataset) - val_size - test_size

# Split the dataset into training, validation, and test sets
train_dataset, val_dataset, test_dataset = torch.utils.data.random_split(dataset, [train_size, val_size, test_size])
train_dataloader = torch.utils.data.DataLoader(train_dataset, batch_size=32, shuffle=True, num_workers=4)
val_dataloader = torch.utils.data.DataLoader(val_dataset, batch_size=32, shuffle=False, num_workers=4)
test_dataloader = torch.utils.data.DataLoader(test_dataset, batch_size=32, shuffle=False, num_workers=4)
loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model_5_8.parameters(), lr=0.0001)

length = len(train_dataloader)
for epoch in range(50):  # loop over the dataset multiple times
    print("Epoch: ", epoch+1)
    pbar = tqdm(train_dataloader)
    running_loss = 0.0
    correct = 0
    total_seen = 0
    for inputs, labels in pbar:
        inputs, labels = inputs.to(device), labels.to(device)
        # total_seen += labels.size(0)

        optimizer.zero_grad()

        outputs = model_5_8(inputs)
        loss = loss_fn(outputs, labels)
        loss.backward()
        optimizer.step()

        running_loss += loss.item()
        # print(labels.shape, outputs)
        _, predicted = torch.max(outputs.data, 1)
        # print(predicted.size(0))
        total_seen+=predicted.size(0)
        correct += (predicted == labels).sum().item()
        pbar.set_description(f"Running Accuracy: { correct/total_seen},  \t Batch Loss:  {loss}")

    # Testing the network
    correct = 0
    total = 0
    total_loss=0
    with torch.no_grad():
        for (images, labels) in val_dataloader:
            images, labels = images.to(device), labels.to(device)
            outputs = model_5_8(images)
            val_loss = nn.CrossEntropyLoss()(outputs, labels)
            total_loss += loss.item()
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
        print("Accuracy: ", correct/total, "\tLoss: ", total_loss/len(val_dataloader))

Epoch:  1


Running Accuracy: 0.2775,  	 Batch Loss:  1.382874608039856: 100%|██████████| 38/38 [00:06<00:00,  5.47it/s]              


Accuracy:  0.2725 	Loss:  1.382874608039856
Epoch:  2


Running Accuracy: 0.37916666666666665,  	 Batch Loss:  1.2245548963546753: 100%|██████████| 38/38 [00:06<00:00,  5.46it/s]


Accuracy:  0.425 	Loss:  1.2245548963546753
Epoch:  3


Running Accuracy: 0.43,  	 Batch Loss:  1.1750215291976929: 100%|██████████| 38/38 [00:06<00:00,  5.46it/s]               


Accuracy:  0.48 	Loss:  1.1750215291976929
Epoch:  4


Running Accuracy: 0.5858333333333333,  	 Batch Loss:  0.9743655323982239: 100%|██████████| 38/38 [00:07<00:00,  5.37it/s]


Accuracy:  0.6025 	Loss:  0.9743655323982239
Epoch:  5


Running Accuracy: 0.61,  	 Batch Loss:  0.7244129776954651: 100%|██████████| 38/38 [00:06<00:00,  5.55it/s]              


Accuracy:  0.5825 	Loss:  0.7244129776954651
Epoch:  6


Running Accuracy: 0.6625,  	 Batch Loss:  0.4462086260318756: 100%|██████████| 38/38 [00:06<00:00,  5.62it/s]            


Accuracy:  0.665 	Loss:  0.4462086260318756
Epoch:  7


Running Accuracy: 0.6775,  	 Batch Loss:  0.6674244403839111: 100%|██████████| 38/38 [00:06<00:00,  5.58it/s]            


Accuracy:  0.64 	Loss:  0.6674244403839111
Epoch:  8


Running Accuracy: 0.7158333333333333,  	 Batch Loss:  0.6011888980865479: 100%|██████████| 38/38 [00:06<00:00,  5.57it/s]


Accuracy:  0.7 	Loss:  0.6011888980865479
Epoch:  9


Running Accuracy: 0.7166666666666667,  	 Batch Loss:  0.8868562579154968: 100%|██████████| 38/38 [00:06<00:00,  5.44it/s]


Accuracy:  0.67 	Loss:  0.8868562579154968
Epoch:  10


Running Accuracy: 0.71,  	 Batch Loss:  0.46963000297546387: 100%|██████████| 38/38 [00:06<00:00,  5.56it/s]             


Accuracy:  0.6925 	Loss:  0.46963000297546387
Epoch:  11


Running Accuracy: 0.7283333333333334,  	 Batch Loss:  0.7116024494171143: 100%|██████████| 38/38 [00:06<00:00,  5.65it/s]


Accuracy:  0.6975 	Loss:  0.7116024494171143
Epoch:  12


Running Accuracy: 0.7216666666666667,  	 Batch Loss:  0.8679160475730896: 100%|██████████| 38/38 [00:06<00:00,  5.57it/s] 


Accuracy:  0.6875 	Loss:  0.8679160475730896
Epoch:  13


Running Accuracy: 0.7283333333333334,  	 Batch Loss:  0.6531089544296265: 100%|██████████| 38/38 [00:06<00:00,  5.62it/s]


Accuracy:  0.6975 	Loss:  0.6531089544296265
Epoch:  14


Running Accuracy: 0.7525,  	 Batch Loss:  0.6980422139167786: 100%|██████████| 38/38 [00:06<00:00,  5.64it/s]             


Accuracy:  0.7275 	Loss:  0.6980422139167786
Epoch:  15


Running Accuracy: 0.7558333333333334,  	 Batch Loss:  0.812866747379303: 100%|██████████| 38/38 [00:06<00:00,  5.67it/s]  


Accuracy:  0.7075 	Loss:  0.812866747379303
Epoch:  16


Running Accuracy: 0.7441666666666666,  	 Batch Loss:  0.5125388503074646: 100%|██████████| 38/38 [00:06<00:00,  5.56it/s] 


Accuracy:  0.69 	Loss:  0.5125388503074646
Epoch:  17


Running Accuracy: 0.7383333333333333,  	 Batch Loss:  0.5219888687133789: 100%|██████████| 38/38 [00:06<00:00,  5.53it/s] 


Accuracy:  0.7025 	Loss:  0.5219888687133789
Epoch:  18


Running Accuracy: 0.76,  	 Batch Loss:  0.6077507734298706: 100%|██████████| 38/38 [00:06<00:00,  5.64it/s]               


Accuracy:  0.725 	Loss:  0.6077507734298706
Epoch:  19


Running Accuracy: 0.75,  	 Batch Loss:  0.8724005818367004: 100%|██████████| 38/38 [00:06<00:00,  5.59it/s]              


Accuracy:  0.715 	Loss:  0.8724005818367004
Epoch:  20


Running Accuracy: 0.7608333333333334,  	 Batch Loss:  0.8054170608520508: 100%|██████████| 38/38 [00:06<00:00,  5.67it/s] 


Accuracy:  0.7225 	Loss:  0.8054170608520508
Epoch:  21


Running Accuracy: 0.76,  	 Batch Loss:  0.43332505226135254: 100%|██████████| 38/38 [00:06<00:00,  5.55it/s]              


Accuracy:  0.7125 	Loss:  0.43332505226135254
Epoch:  22


Running Accuracy: 0.7683333333333333,  	 Batch Loss:  0.6501283049583435: 100%|██████████| 38/38 [00:06<00:00,  5.44it/s] 


Accuracy:  0.7325 	Loss:  0.6501283049583435
Epoch:  23


Running Accuracy: 0.7616666666666667,  	 Batch Loss:  0.6379033923149109: 100%|██████████| 38/38 [00:06<00:00,  5.45it/s] 


Accuracy:  0.74 	Loss:  0.6379033923149109
Epoch:  24


Running Accuracy: 0.7791666666666667,  	 Batch Loss:  0.4628298282623291: 100%|██████████| 38/38 [00:06<00:00,  5.48it/s] 


Accuracy:  0.7425 	Loss:  0.4628298282623291
Epoch:  25


Running Accuracy: 0.79,  	 Batch Loss:  0.668923020362854: 100%|██████████| 38/38 [00:06<00:00,  5.53it/s]                


Accuracy:  0.7525 	Loss:  0.668923020362854
Epoch:  26


Running Accuracy: 0.7866666666666666,  	 Batch Loss:  0.7308411598205566: 100%|██████████| 38/38 [00:06<00:00,  5.57it/s] 


Accuracy:  0.7325 	Loss:  0.7308411598205566
Epoch:  27


Running Accuracy: 0.7966666666666666,  	 Batch Loss:  0.8542525172233582: 100%|██████████| 38/38 [00:06<00:00,  5.65it/s] 


Accuracy:  0.7375 	Loss:  0.8542525172233582
Epoch:  28


Running Accuracy: 0.7883333333333333,  	 Batch Loss:  0.5871992111206055: 100%|██████████| 38/38 [00:06<00:00,  5.51it/s] 


Accuracy:  0.71 	Loss:  0.5871992111206055
Epoch:  29


Running Accuracy: 0.78,  	 Batch Loss:  0.6921902894973755: 100%|██████████| 38/38 [00:07<00:00,  5.36it/s]               


Accuracy:  0.7175 	Loss:  0.6921902894973755
Epoch:  30


Running Accuracy: 0.82,  	 Batch Loss:  0.3597928583621979: 100%|██████████| 38/38 [00:06<00:00,  5.66it/s]               


Accuracy:  0.73 	Loss:  0.3597928583621979
Epoch:  31


Running Accuracy: 0.7841666666666667,  	 Batch Loss:  0.23085832595825195: 100%|██████████| 38/38 [00:06<00:00,  5.60it/s]


Accuracy:  0.7325 	Loss:  0.23085832595825195
Epoch:  32


Running Accuracy: 0.81,  	 Batch Loss:  0.40621039271354675: 100%|██████████| 38/38 [00:06<00:00,  5.62it/s]              


Accuracy:  0.7425 	Loss:  0.40621039271354675
Epoch:  33


Running Accuracy: 0.7983333333333333,  	 Batch Loss:  0.5557562112808228: 100%|██████████| 38/38 [00:06<00:00,  5.62it/s] 


Accuracy:  0.7425 	Loss:  0.5557562112808228
Epoch:  34


Running Accuracy: 0.8041666666666667,  	 Batch Loss:  0.45666801929473877: 100%|██████████| 38/38 [00:06<00:00,  5.67it/s]


Accuracy:  0.79 	Loss:  0.45666801929473877
Epoch:  35


Running Accuracy: 0.7891666666666667,  	 Batch Loss:  0.3102540671825409: 100%|██████████| 38/38 [00:06<00:00,  5.64it/s] 


Accuracy:  0.755 	Loss:  0.3102540671825409
Epoch:  36


Running Accuracy: 0.8083333333333333,  	 Batch Loss:  0.20438016951084137: 100%|██████████| 38/38 [00:06<00:00,  5.58it/s]


Accuracy:  0.7975 	Loss:  0.20438016951084137
Epoch:  37


Running Accuracy: 0.805,  	 Batch Loss:  0.5123772025108337: 100%|██████████| 38/38 [00:06<00:00,  5.65it/s]              


Accuracy:  0.7575 	Loss:  0.5123772025108337
Epoch:  38


Running Accuracy: 0.8183333333333334,  	 Batch Loss:  0.3457542359828949: 100%|██████████| 38/38 [00:06<00:00,  5.66it/s] 


Accuracy:  0.7725 	Loss:  0.3457542359828949
Epoch:  39


Running Accuracy: 0.8166666666666667,  	 Batch Loss:  0.303949773311615: 100%|██████████| 38/38 [00:06<00:00,  5.54it/s]  


Accuracy:  0.7875 	Loss:  0.303949773311615
Epoch:  40


Running Accuracy: 0.8266666666666667,  	 Batch Loss:  0.4456564784049988: 100%|██████████| 38/38 [00:06<00:00,  5.57it/s] 


Accuracy:  0.7875 	Loss:  0.4456564784049988
Epoch:  41


Running Accuracy: 0.8191666666666667,  	 Batch Loss:  0.3338683247566223: 100%|██████████| 38/38 [00:06<00:00,  5.53it/s] 


Accuracy:  0.7525 	Loss:  0.3338683247566223
Epoch:  42


Running Accuracy: 0.8133333333333334,  	 Batch Loss:  0.30857598781585693: 100%|██████████| 38/38 [00:06<00:00,  5.63it/s]


Accuracy:  0.795 	Loss:  0.30857598781585693
Epoch:  43


Running Accuracy: 0.8283333333333334,  	 Batch Loss:  0.5112588405609131: 100%|██████████| 38/38 [00:07<00:00,  5.39it/s] 


Accuracy:  0.755 	Loss:  0.5112588405609131
Epoch:  44


Running Accuracy: 0.8325,  	 Batch Loss:  0.2938425540924072: 100%|██████████| 38/38 [00:06<00:00,  5.62it/s]             


Accuracy:  0.7875 	Loss:  0.2938425540924072
Epoch:  45


Running Accuracy: 0.8375,  	 Batch Loss:  0.32689347863197327: 100%|██████████| 38/38 [00:06<00:00,  5.71it/s]            


Accuracy:  0.805 	Loss:  0.32689347863197327
Epoch:  46


Running Accuracy: 0.8366666666666667,  	 Batch Loss:  0.44292160868644714: 100%|██████████| 38/38 [00:06<00:00,  5.63it/s]


Accuracy:  0.8025 	Loss:  0.44292160868644714
Epoch:  47


Running Accuracy: 0.84,  	 Batch Loss:  0.31570279598236084: 100%|██████████| 38/38 [00:06<00:00,  5.59it/s]              


Accuracy:  0.7875 	Loss:  0.31570279598236084
Epoch:  48


Running Accuracy: 0.8491666666666666,  	 Batch Loss:  0.3269840180873871: 100%|██████████| 38/38 [00:06<00:00,  5.54it/s] 


Accuracy:  0.7875 	Loss:  0.3269840180873871
Epoch:  49


Running Accuracy: 0.85,  	 Batch Loss:  0.15433791279792786: 100%|██████████| 38/38 [00:06<00:00,  5.52it/s]              


Accuracy:  0.8025 	Loss:  0.15433791279792786
Epoch:  50


Running Accuracy: 0.8366666666666667,  	 Batch Loss:  0.5248077511787415: 100%|██████████| 38/38 [00:06<00:00,  5.51it/s] 


Accuracy:  0.8075 	Loss:  0.5248077511787415


### 9_12

In [16]:
root_dir = r'Data\Top_level\9_12'
dataset = datasets.ImageFolder(root=root_dir, transform=transform)
model_9_12 = CNN(4).to(device)
val_size = int(0.2 * len(dataset))
test_size = int(0.2 * len(dataset))
train_size = len(dataset) - val_size - test_size

# Split the dataset into training, validation, and test sets
train_dataset, val_dataset, test_dataset = torch.utils.data.random_split(dataset, [train_size, val_size, test_size])
train_dataloader = torch.utils.data.DataLoader(train_dataset, batch_size=32, shuffle=True, num_workers=4)
val_dataloader = torch.utils.data.DataLoader(val_dataset, batch_size=32, shuffle=False, num_workers=4)
test_dataloader = torch.utils.data.DataLoader(test_dataset, batch_size=32, shuffle=False, num_workers=4)
loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model_9_12.parameters(), lr=0.0001)

length = len(train_dataloader)
for epoch in range(50):  # loop over the dataset multiple times
    print("Epoch: ", epoch+1)
    pbar = tqdm(train_dataloader)
    running_loss = 0.0
    correct = 0
    total_seen = 0
    for inputs, labels in pbar:
        inputs, labels = inputs.to(device), labels.to(device)
        # total_seen += labels.size(0)

        optimizer.zero_grad()

        outputs = model_9_12(inputs)
        loss = loss_fn(outputs, labels)
        loss.backward()
        optimizer.step()

        running_loss += loss.item()
        # print(labels.shape, outputs)
        _, predicted = torch.max(outputs.data, 1)
        # print(predicted.size(0))
        total_seen+=predicted.size(0)
        correct += (predicted == labels).sum().item()
        pbar.set_description(f"Running Accuracy: { correct/total_seen},  \t Batch Loss:  {loss}")

    # Testing the network
    correct = 0
    total = 0
    total_loss=0
    with torch.no_grad():
        for (images, labels) in val_dataloader:
            images, labels = images.to(device), labels.to(device)
            outputs = model_9_12(images)
            val_loss = nn.CrossEntropyLoss()(outputs, labels)
            total_loss += loss.item()
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
        print("Accuracy: ", correct/total, "\tLoss: ", total_loss/len(val_dataloader))

Epoch:  1


  0%|          | 0/38 [00:00<?, ?it/s]

Running Accuracy: 0.23416666666666666,  	 Batch Loss:  1.4087969064712524: 100%|██████████| 38/38 [00:06<00:00,  5.45it/s]


Accuracy:  0.2525 	Loss:  1.4087969064712524
Epoch:  2


Running Accuracy: 0.25083333333333335,  	 Batch Loss:  1.3913350105285645: 100%|██████████| 38/38 [00:06<00:00,  5.49it/s]


Accuracy:  0.2525 	Loss:  1.3913350105285645
Epoch:  3


Running Accuracy: 0.25916666666666666,  	 Batch Loss:  1.3717799186706543: 100%|██████████| 38/38 [00:06<00:00,  5.48it/s]


Accuracy:  0.2525 	Loss:  1.3717799186706543
Epoch:  4


Running Accuracy: 0.30333333333333334,  	 Batch Loss:  1.3770785331726074: 100%|██████████| 38/38 [00:06<00:00,  5.63it/s]


Accuracy:  0.3025 	Loss:  1.3770785331726074
Epoch:  5


Running Accuracy: 0.3333333333333333,  	 Batch Loss:  1.3184020519256592: 100%|██████████| 38/38 [00:06<00:00,  5.68it/s] 


Accuracy:  0.3525 	Loss:  1.3184020519256592
Epoch:  6


Running Accuracy: 0.37916666666666665,  	 Batch Loss:  1.2368170022964478: 100%|██████████| 38/38 [00:06<00:00,  5.69it/s]


Accuracy:  0.39 	Loss:  1.2368170022964478
Epoch:  7


Running Accuracy: 0.4275,  	 Batch Loss:  1.3262685537338257: 100%|██████████| 38/38 [00:06<00:00,  5.52it/s]             


Accuracy:  0.4725 	Loss:  1.3262685537338257
Epoch:  8


Running Accuracy: 0.4741666666666667,  	 Batch Loss:  0.9944013953208923: 100%|██████████| 38/38 [00:06<00:00,  5.47it/s] 


Accuracy:  0.4475 	Loss:  0.9944013953208923
Epoch:  9


Running Accuracy: 0.46,  	 Batch Loss:  0.989599347114563: 100%|██████████| 38/38 [00:06<00:00,  5.54it/s]                


Accuracy:  0.465 	Loss:  0.989599347114563
Epoch:  10


Running Accuracy: 0.5091666666666667,  	 Batch Loss:  1.0244166851043701: 100%|██████████| 38/38 [00:06<00:00,  5.53it/s] 


Accuracy:  0.5025 	Loss:  1.0244166851043701
Epoch:  11


Running Accuracy: 0.5333333333333333,  	 Batch Loss:  0.8614354133605957: 100%|██████████| 38/38 [00:06<00:00,  5.57it/s]


Accuracy:  0.5175 	Loss:  0.8614354133605957
Epoch:  12


Running Accuracy: 0.5625,  	 Batch Loss:  0.8789034485816956: 100%|██████████| 38/38 [00:06<00:00,  5.58it/s]            


Accuracy:  0.5625 	Loss:  0.8789034485816956
Epoch:  13


Running Accuracy: 0.56,  	 Batch Loss:  0.8926329612731934: 100%|██████████| 38/38 [00:06<00:00,  5.56it/s]              


Accuracy:  0.5375 	Loss:  0.8926329612731934
Epoch:  14


Running Accuracy: 0.565,  	 Batch Loss:  0.947144091129303: 100%|██████████| 38/38 [00:06<00:00,  5.60it/s]              


Accuracy:  0.575 	Loss:  0.947144091129303
Epoch:  15


Running Accuracy: 0.58,  	 Batch Loss:  1.0513465404510498: 100%|██████████| 38/38 [00:06<00:00,  5.64it/s]              


Accuracy:  0.52 	Loss:  1.0513465404510498
Epoch:  16


Running Accuracy: 0.5575,  	 Batch Loss:  1.1584076881408691: 100%|██████████| 38/38 [00:06<00:00,  5.68it/s]            


Accuracy:  0.545 	Loss:  1.1584076881408691
Epoch:  17


Running Accuracy: 0.5533333333333333,  	 Batch Loss:  1.062120795249939: 100%|██████████| 38/38 [00:06<00:00,  5.62it/s] 


Accuracy:  0.5075 	Loss:  1.062120795249939
Epoch:  18


Running Accuracy: 0.5525,  	 Batch Loss:  0.7228960394859314: 100%|██████████| 38/38 [00:06<00:00,  5.62it/s]            


Accuracy:  0.555 	Loss:  0.7228960394859314
Epoch:  19


Running Accuracy: 0.5641666666666667,  	 Batch Loss:  0.8055214881896973: 100%|██████████| 38/38 [00:06<00:00,  5.65it/s]


Accuracy:  0.58 	Loss:  0.8055214881896973
Epoch:  20


Running Accuracy: 0.5808333333333333,  	 Batch Loss:  0.9248278141021729: 100%|██████████| 38/38 [00:06<00:00,  5.63it/s]


Accuracy:  0.55 	Loss:  0.9248278141021729
Epoch:  21


Running Accuracy: 0.5641666666666667,  	 Batch Loss:  0.8922404050827026: 100%|██████████| 38/38 [00:07<00:00,  5.34it/s]


Accuracy:  0.59 	Loss:  0.8922404050827026
Epoch:  22


Running Accuracy: 0.5883333333333334,  	 Batch Loss:  0.8652648329734802: 100%|██████████| 38/38 [00:06<00:00,  5.56it/s]


Accuracy:  0.6025 	Loss:  0.8652648329734802
Epoch:  23


Running Accuracy: 0.5658333333333333,  	 Batch Loss:  0.9306521415710449: 100%|██████████| 38/38 [00:06<00:00,  5.52it/s]


Accuracy:  0.5675 	Loss:  0.9306521415710449
Epoch:  24


Running Accuracy: 0.5741666666666667,  	 Batch Loss:  0.5859009623527527: 100%|██████████| 38/38 [00:06<00:00,  5.54it/s]


Accuracy:  0.55 	Loss:  0.5859009623527527
Epoch:  25


Running Accuracy: 0.585,  	 Batch Loss:  0.804582417011261: 100%|██████████| 38/38 [00:06<00:00,  5.58it/s]              


Accuracy:  0.5425 	Loss:  0.804582417011261
Epoch:  26


Running Accuracy: 0.5841666666666666,  	 Batch Loss:  0.6992474794387817: 100%|██████████| 38/38 [00:06<00:00,  5.50it/s]


Accuracy:  0.6025 	Loss:  0.6992474794387817
Epoch:  27


Running Accuracy: 0.6058333333333333,  	 Batch Loss:  0.8350667953491211: 100%|██████████| 38/38 [00:06<00:00,  5.58it/s]


Accuracy:  0.5825 	Loss:  0.8350667953491211
Epoch:  28


Running Accuracy: 0.5725,  	 Batch Loss:  0.8299905061721802: 100%|██████████| 38/38 [00:06<00:00,  5.54it/s]            


Accuracy:  0.58 	Loss:  0.8299905061721802
Epoch:  29


Running Accuracy: 0.5925,  	 Batch Loss:  0.7740203142166138: 100%|██████████| 38/38 [00:06<00:00,  5.59it/s]            


Accuracy:  0.5825 	Loss:  0.7740203142166138
Epoch:  30


Running Accuracy: 0.5725,  	 Batch Loss:  1.055105209350586: 100%|██████████| 38/38 [00:06<00:00,  5.56it/s]             


Accuracy:  0.545 	Loss:  1.055105209350586
Epoch:  31


Running Accuracy: 0.6141666666666666,  	 Batch Loss:  1.0004467964172363: 100%|██████████| 38/38 [00:06<00:00,  5.54it/s]


Accuracy:  0.59 	Loss:  1.0004467964172363
Epoch:  32


Running Accuracy: 0.6066666666666667,  	 Batch Loss:  0.9482566714286804: 100%|██████████| 38/38 [00:06<00:00,  5.54it/s]


Accuracy:  0.515 	Loss:  0.9482566714286804
Epoch:  33


Running Accuracy: 0.6033333333333334,  	 Batch Loss:  0.7275375723838806: 100%|██████████| 38/38 [00:06<00:00,  5.58it/s]


Accuracy:  0.575 	Loss:  0.7275375723838806
Epoch:  34


Running Accuracy: 0.615,  	 Batch Loss:  0.6794622540473938: 100%|██████████| 38/38 [00:06<00:00,  5.56it/s]             


Accuracy:  0.53 	Loss:  0.6794622540473938
Epoch:  35


Running Accuracy: 0.5858333333333333,  	 Batch Loss:  0.7255483865737915: 100%|██████████| 38/38 [00:06<00:00,  5.55it/s]


Accuracy:  0.56 	Loss:  0.7255483865737915
Epoch:  36


Running Accuracy: 0.61,  	 Batch Loss:  0.8182905316352844: 100%|██████████| 38/38 [00:06<00:00,  5.62it/s]              


Accuracy:  0.5475 	Loss:  0.8182905316352844
Epoch:  37


Running Accuracy: 0.6066666666666667,  	 Batch Loss:  0.9436962604522705: 100%|██████████| 38/38 [00:06<00:00,  5.58it/s]


Accuracy:  0.59 	Loss:  0.9436962604522705
Epoch:  38


Running Accuracy: 0.6016666666666667,  	 Batch Loss:  0.7099394798278809: 100%|██████████| 38/38 [00:06<00:00,  5.58it/s]


Accuracy:  0.6025 	Loss:  0.7099394798278809
Epoch:  39


Running Accuracy: 0.6333333333333333,  	 Batch Loss:  0.8101033568382263: 100%|██████████| 38/38 [00:06<00:00,  5.57it/s]


Accuracy:  0.565 	Loss:  0.8101033568382263
Epoch:  40


Running Accuracy: 0.6083333333333333,  	 Batch Loss:  0.9963012337684631: 100%|██████████| 38/38 [00:06<00:00,  5.65it/s]


Accuracy:  0.55 	Loss:  0.9963012337684631
Epoch:  41


Running Accuracy: 0.6,  	 Batch Loss:  0.994367778301239: 100%|██████████| 38/38 [00:06<00:00,  5.66it/s]                


Accuracy:  0.5825 	Loss:  0.994367778301239
Epoch:  42


Running Accuracy: 0.6233333333333333,  	 Batch Loss:  0.9089406132698059: 100%|██████████| 38/38 [00:06<00:00,  5.62it/s]


Accuracy:  0.5825 	Loss:  0.9089406132698059
Epoch:  43


Running Accuracy: 0.6191666666666666,  	 Batch Loss:  0.6494427919387817: 100%|██████████| 38/38 [00:06<00:00,  5.61it/s]


Accuracy:  0.5825 	Loss:  0.6494427919387817
Epoch:  44


Running Accuracy: 0.6333333333333333,  	 Batch Loss:  0.44403019547462463: 100%|██████████| 38/38 [00:06<00:00,  5.65it/s]


Accuracy:  0.59 	Loss:  0.44403019547462463
Epoch:  45


Running Accuracy: 0.6408333333333334,  	 Batch Loss:  0.6361261606216431: 100%|██████████| 38/38 [00:06<00:00,  5.64it/s]


Accuracy:  0.6025 	Loss:  0.6361261606216431
Epoch:  46


Running Accuracy: 0.6333333333333333,  	 Batch Loss:  0.9299378395080566: 100%|██████████| 38/38 [00:06<00:00,  5.58it/s]


Accuracy:  0.615 	Loss:  0.9299378395080566
Epoch:  47


Running Accuracy: 0.6308333333333334,  	 Batch Loss:  0.9793525338172913: 100%|██████████| 38/38 [00:06<00:00,  5.60it/s] 


Accuracy:  0.615 	Loss:  0.9793525338172913
Epoch:  48


Running Accuracy: 0.6266666666666667,  	 Batch Loss:  0.7277253270149231: 100%|██████████| 38/38 [00:07<00:00,  5.37it/s]


Accuracy:  0.55 	Loss:  0.7277253270149231
Epoch:  49


Running Accuracy: 0.61,  	 Batch Loss:  0.6927796006202698: 100%|██████████| 38/38 [00:06<00:00,  5.49it/s]              


Accuracy:  0.5925 	Loss:  0.6927796006202698
Epoch:  50


Running Accuracy: 0.6391666666666667,  	 Batch Loss:  0.7316828370094299: 100%|██████████| 38/38 [00:06<00:00,  5.63it/s]


Accuracy:  0.6 	Loss:  0.7316828370094299


### Top

In [17]:
root_dir = r'Data\Top_level'
dataset = datasets.ImageFolder(root=root_dir, transform=transform)
model_top = CNN(3).to(device)
val_size = int(0.2 * len(dataset))
test_size = int(0.2 * len(dataset))
train_size = len(dataset) - val_size - test_size

# Split the dataset into training, validation, and test sets
train_dataset, val_dataset, test_dataset = torch.utils.data.random_split(dataset, [train_size, val_size, test_size])
train_dataloader = torch.utils.data.DataLoader(train_dataset, batch_size=32, shuffle=True, num_workers=4)
val_dataloader = torch.utils.data.DataLoader(val_dataset, batch_size=32, shuffle=False, num_workers=4)
test_dataloader = torch.utils.data.DataLoader(test_dataset, batch_size=32, shuffle=False, num_workers=4)
loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model_top.parameters(), lr=0.0001)

length = len(train_dataloader)
for epoch in range(50):  # loop over the dataset multiple times
    print("Epoch: ", epoch+1)
    pbar = tqdm(train_dataloader)
    running_loss = 0.0
    correct = 0
    total_seen = 0
    for inputs, labels in pbar:
        inputs, labels = inputs.to(device), labels.to(device)
        # total_seen += labels.size(0)

        optimizer.zero_grad()

        outputs = model_top(inputs)
        loss = loss_fn(outputs, labels)
        loss.backward()
        optimizer.step()

        running_loss += loss.item()
        # print(labels.shape, outputs)
        _, predicted = torch.max(outputs.data, 1)
        # print(predicted.size(0))
        total_seen+=predicted.size(0)
        correct += (predicted == labels).sum().item()
        pbar.set_description(f"Running Accuracy: { correct/total_seen},  \t Batch Loss:  {loss}")

    # Testing the network
    correct = 0
    total = 0
    total_loss=0
    with torch.no_grad():
        for (images, labels) in val_dataloader:
            images, labels = images.to(device), labels.to(device)
            outputs = model_top(images)
            val_loss = nn.CrossEntropyLoss()(outputs, labels)
            total_loss += loss.item()
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
        print("Accuracy: ", correct/total, "\tLoss: ", total_loss/len(val_dataloader))

Epoch:  1


Running Accuracy: 0.4520966398222716,  	 Batch Loss:  0.999234676361084: 100%|██████████| 113/113 [00:07<00:00, 15.82it/s]  


Accuracy:  0.5216666666666666 	Loss:  0.999234676361084
Epoch:  2


Running Accuracy: 0.6203832268814218,  	 Batch Loss:  0.6325085759162903: 100%|██████████| 113/113 [00:07<00:00, 15.59it/s]


Accuracy:  0.7008333333333333 	Loss:  0.6325085759162903
Epoch:  3


Running Accuracy: 0.7236878644820883,  	 Batch Loss:  0.5373547673225403: 100%|██████████| 113/113 [00:07<00:00, 15.83it/s]


Accuracy:  0.6891666666666667 	Loss:  0.5373547673225403
Epoch:  4


Running Accuracy: 0.7389613996112191,  	 Batch Loss:  0.6747811436653137: 100%|██████████| 113/113 [00:07<00:00, 15.75it/s]


Accuracy:  0.7741666666666667 	Loss:  0.6747811436653137
Epoch:  5


Running Accuracy: 0.7589558455984449,  	 Batch Loss:  0.4136815369129181: 100%|██████████| 113/113 [00:07<00:00, 15.93it/s]


Accuracy:  0.7666666666666667 	Loss:  0.4136815369129181
Epoch:  6


Running Accuracy: 0.7745070813662871,  	 Batch Loss:  0.4479196071624756: 100%|██████████| 113/113 [00:07<00:00, 15.62it/s] 


Accuracy:  0.8075 	Loss:  0.4479196071624756
Epoch:  7


Running Accuracy: 0.7839489030824771,  	 Batch Loss:  0.5304334759712219: 100%|██████████| 113/113 [00:07<00:00, 15.74it/s]


Accuracy:  0.7941666666666667 	Loss:  0.5304334759712219
Epoch:  8


Running Accuracy: 0.798667036934185,  	 Batch Loss:  0.3233571946620941: 100%|██████████| 113/113 [00:07<00:00, 15.83it/s]  


Accuracy:  0.8233333333333334 	Loss:  0.3233571946620941
Epoch:  9


Running Accuracy: 0.8056095529019717,  	 Batch Loss:  0.3118130564689636: 100%|██████████| 113/113 [00:07<00:00, 15.98it/s] 


Accuracy:  0.8375 	Loss:  0.3118130564689636
Epoch:  10


Running Accuracy: 0.806998056095529,  	 Batch Loss:  0.4382355213165283: 100%|██████████| 113/113 [00:07<00:00, 15.86it/s] 


Accuracy:  0.8291666666666667 	Loss:  0.4382355213165283
Epoch:  11


Running Accuracy: 0.8150513746181616,  	 Batch Loss:  0.5918312668800354: 100%|██████████| 113/113 [00:07<00:00, 15.81it/s] 


Accuracy:  0.78 	Loss:  0.5918312668800354
Epoch:  12


Running Accuracy: 0.8169952790891419,  	 Batch Loss:  0.5319418907165527: 100%|██████████| 113/113 [00:07<00:00, 15.87it/s]


Accuracy:  0.81 	Loss:  0.5319418907165527
Epoch:  13


Running Accuracy: 0.8169952790891419,  	 Batch Loss:  0.38081496953964233: 100%|██████████| 113/113 [00:07<00:00, 15.92it/s]


Accuracy:  0.8341666666666666 	Loss:  0.38081496953964233
Epoch:  14


Running Accuracy: 0.8256039988891974,  	 Batch Loss:  0.16306012868881226: 100%|██████████| 113/113 [00:07<00:00, 15.91it/s]


Accuracy:  0.8416666666666667 	Loss:  0.16306012868881226
Epoch:  15


Running Accuracy: 0.8272702027214662,  	 Batch Loss:  0.3580331802368164: 100%|██████████| 113/113 [00:07<00:00, 15.93it/s] 


Accuracy:  0.7975 	Loss:  0.3580331802368164
Epoch:  16


Running Accuracy: 0.8317134129408498,  	 Batch Loss:  0.3416227102279663: 100%|██████████| 113/113 [00:07<00:00, 15.78it/s] 


Accuracy:  0.8416666666666667 	Loss:  0.3416227102279663
Epoch:  17


Running Accuracy: 0.8283810052763121,  	 Batch Loss:  0.3906089961528778: 100%|██████████| 113/113 [00:07<00:00, 15.73it/s] 


Accuracy:  0.8433333333333334 	Loss:  0.3906089961528778
Epoch:  18


Running Accuracy: 0.8222715912246599,  	 Batch Loss:  0.36499783396720886: 100%|██████████| 113/113 [00:07<00:00, 15.99it/s]


Accuracy:  0.81 	Loss:  0.36499783396720886
Epoch:  19


Running Accuracy: 0.8325465148569842,  	 Batch Loss:  0.2268846184015274: 100%|██████████| 113/113 [00:07<00:00, 16.09it/s]


Accuracy:  0.85 	Loss:  0.2268846184015274
Epoch:  20


Running Accuracy: 0.8353235212440988,  	 Batch Loss:  0.22283309698104858: 100%|██████████| 113/113 [00:07<00:00, 16.08it/s]


Accuracy:  0.845 	Loss:  0.22283309698104858
Epoch:  21


Running Accuracy: 0.843654540405443,  	 Batch Loss:  0.5659303069114685: 100%|██████████| 113/113 [00:07<00:00, 16.05it/s]  


Accuracy:  0.8466666666666667 	Loss:  0.5659303069114685
Epoch:  22


Running Accuracy: 0.8455984448764232,  	 Batch Loss:  0.5684227347373962: 100%|██████████| 113/113 [00:06<00:00, 16.29it/s] 


Accuracy:  0.8241666666666667 	Loss:  0.5684227347373962
Epoch:  23


Running Accuracy: 0.8444876423215774,  	 Batch Loss:  0.27238938212394714: 100%|██████████| 113/113 [00:07<00:00, 16.10it/s]


Accuracy:  0.8233333333333334 	Loss:  0.27238938212394714
Epoch:  24


Running Accuracy: 0.843654540405443,  	 Batch Loss:  0.24315780401229858: 100%|██████████| 113/113 [00:07<00:00, 15.74it/s]


Accuracy:  0.8483333333333334 	Loss:  0.24315780401229858
Epoch:  25


Running Accuracy: 0.8458761455151347,  	 Batch Loss:  0.20063602924346924: 100%|██████████| 113/113 [00:07<00:00, 16.09it/s]


Accuracy:  0.8441666666666666 	Loss:  0.20063602924346924
Epoch:  26


Running Accuracy: 0.8461538461538461,  	 Batch Loss:  0.40305402874946594: 100%|██████████| 113/113 [00:07<00:00, 15.84it/s]


Accuracy:  0.8516666666666667 	Loss:  0.40305402874946594
Epoch:  27


Running Accuracy: 0.8528186614829214,  	 Batch Loss:  0.4437718987464905: 100%|██████████| 113/113 [00:06<00:00, 16.23it/s] 


Accuracy:  0.8458333333333333 	Loss:  0.4437718987464905
Epoch:  28


Running Accuracy: 0.8480977506248264,  	 Batch Loss:  0.2361593246459961: 100%|██████████| 113/113 [00:07<00:00, 15.93it/s] 


Accuracy:  0.8483333333333334 	Loss:  0.2361593246459961
Epoch:  29


Running Accuracy: 0.8530963621216329,  	 Batch Loss:  0.4218403697013855: 100%|██████████| 113/113 [00:07<00:00, 15.87it/s]


Accuracy:  0.8533333333333334 	Loss:  0.4218403697013855
Epoch:  30


Running Accuracy: 0.8469869480699805,  	 Batch Loss:  0.2393733710050583: 100%|██████████| 113/113 [00:07<00:00, 16.01it/s]


Accuracy:  0.8441666666666666 	Loss:  0.2393733710050583
Epoch:  31


Running Accuracy: 0.8508747570119412,  	 Batch Loss:  0.8397684097290039: 100%|██████████| 113/113 [00:07<00:00, 16.14it/s]


Accuracy:  0.8516666666666667 	Loss:  0.8397684097290039
Epoch:  32


Running Accuracy: 0.8614273812829769,  	 Batch Loss:  0.4971395432949066: 100%|██████████| 113/113 [00:07<00:00, 15.79it/s]


Accuracy:  0.8408333333333333 	Loss:  0.4971395432949066
Epoch:  33


Running Accuracy: 0.8630935851152458,  	 Batch Loss:  0.23867680132389069: 100%|██████████| 113/113 [00:07<00:00, 15.89it/s]


Accuracy:  0.8591666666666666 	Loss:  0.23867680132389069
Epoch:  34


Running Accuracy: 0.8600388780894196,  	 Batch Loss:  0.5118176937103271: 100%|██████████| 113/113 [00:07<00:00, 16.00it/s]


Accuracy:  0.8475 	Loss:  0.5118176937103271
Epoch:  35


Running Accuracy: 0.8589280755345737,  	 Batch Loss:  0.2312222421169281: 100%|██████████| 113/113 [00:07<00:00, 15.90it/s] 


Accuracy:  0.8525 	Loss:  0.2312222421169281
Epoch:  36


Running Accuracy: 0.8569841710635935,  	 Batch Loss:  0.22711515426635742: 100%|██████████| 113/113 [00:07<00:00, 15.78it/s]


Accuracy:  0.8633333333333333 	Loss:  0.22711515426635742
Epoch:  37


Running Accuracy: 0.855595667870036,  	 Batch Loss:  0.3198682963848114: 100%|██████████| 113/113 [00:07<00:00, 16.11it/s] 


Accuracy:  0.86 	Loss:  0.3198682963848114
Epoch:  38


Running Accuracy: 0.8611496806442654,  	 Batch Loss:  0.6181073188781738: 100%|██████████| 113/113 [00:07<00:00, 15.76it/s] 


Accuracy:  0.8491666666666666 	Loss:  0.6181073188781738
Epoch:  39


Running Accuracy: 0.8658705915023605,  	 Batch Loss:  0.42275357246398926: 100%|██████████| 113/113 [00:07<00:00, 15.86it/s]


Accuracy:  0.8683333333333333 	Loss:  0.42275357246398926
Epoch:  40


Running Accuracy: 0.860316578728131,  	 Batch Loss:  0.6354349255561829: 100%|██████████| 113/113 [00:07<00:00, 15.95it/s] 


Accuracy:  0.8491666666666666 	Loss:  0.6354349255561829
Epoch:  41


Running Accuracy: 0.8730908081088586,  	 Batch Loss:  0.1552182137966156: 100%|██████████| 113/113 [00:07<00:00, 16.13it/s]


Accuracy:  0.8666666666666667 	Loss:  0.1552182137966156
Epoch:  42


Running Accuracy: 0.8717023049153013,  	 Batch Loss:  0.22938454151153564: 100%|██████████| 113/113 [00:07<00:00, 15.94it/s]


Accuracy:  0.8533333333333334 	Loss:  0.22938454151153564
Epoch:  43


Running Accuracy: 0.8764232157733963,  	 Batch Loss:  0.46725088357925415: 100%|██████████| 113/113 [00:07<00:00, 15.87it/s]


Accuracy:  0.8625 	Loss:  0.46725088357925415
Epoch:  44


Running Accuracy: 0.8633712857539573,  	 Batch Loss:  0.2182784378528595: 100%|██████████| 113/113 [00:07<00:00, 15.77it/s]


Accuracy:  0.8641666666666666 	Loss:  0.2182784378528595
Epoch:  45


Running Accuracy: 0.8805887253540683,  	 Batch Loss:  0.2022211253643036: 100%|██████████| 113/113 [00:07<00:00, 15.86it/s]


Accuracy:  0.855 	Loss:  0.2022211253643036
Epoch:  46


Running Accuracy: 0.8750347125798389,  	 Batch Loss:  0.21783506870269775: 100%|██████████| 113/113 [00:07<00:00, 15.89it/s]


Accuracy:  0.87 	Loss:  0.21783506870269775
Epoch:  47


Running Accuracy: 0.878644820883088,  	 Batch Loss:  0.155481219291687: 100%|██████████| 113/113 [00:07<00:00, 15.94it/s]   


Accuracy:  0.8525 	Loss:  0.155481219291687
Epoch:  48


Running Accuracy: 0.8705915023604555,  	 Batch Loss:  0.3211476802825928: 100%|██████████| 113/113 [00:07<00:00, 16.10it/s] 


Accuracy:  0.8708333333333333 	Loss:  0.3211476802825928
Epoch:  49


Running Accuracy: 0.8778117189669536,  	 Batch Loss:  0.3923402428627014: 100%|██████████| 113/113 [00:07<00:00, 16.14it/s] 


Accuracy:  0.8716666666666667 	Loss:  0.3923402428627014
Epoch:  50


Running Accuracy: 0.8800333240766454,  	 Batch Loss:  0.31986019015312195: 100%|██████████| 113/113 [00:07<00:00, 16.09it/s]


Accuracy:  0.8666666666666667 	Loss:  0.31986019015312195


# Issue Here?

In [18]:
# Forward pass function
def forward(input_data):
    # Use the top-level model to decide which specialized model to use
    with torch.no_grad():
        decision = model_top(input_data)
        # Get the index (decision) for each input in the batch
        decisions = torch.argmax(decision, dim=1)  # This returns a tensor of decisions for each input in the batch
        print(decisions)

    # Initialize a batch of target tensors
    batch_size = input_data.size(0)
    target_tensor = torch.zeros(batch_size, 12).to(device)  # Initialize the target tensor for the entire batch

    for i in range(batch_size):
        # Route the input to the correct specialized model based on the decision for each element
        if decisions[i] == 0:
            output_tensor = model_1_4(input_data[i].unsqueeze(0))  # Process the ith element
            # target_size = min(output_tensor.size(1), 12)  # Ensure it doesn't exceed target tensor size
            target_tensor[i, :4] = output_tensor.squeeze(0)[:4]  # Overlay on the target tensor
        
        elif decisions[i] == 1:
            output_tensor = model_5_8(input_data[i].unsqueeze(0))  # Process the ith element
            # target_size = min(output_tensor.size(1), 7)  # Ensure it fits in the available space (12-5=7)
            target_tensor[i, 4:4+4] = output_tensor.squeeze(0)[:4]  # Overlay on the target tensor
            
        elif decisions[i] == 2:
            output_tensor = model_9_12(input_data[i].unsqueeze(0))  # Process the ith element
            # target_size = min(output_tensor.size(1), 3)  # Ensure it fits in the available space (12-9=3)
            target_tensor[i, 8:8+4] = output_tensor.squeeze(0)[:4]  # Overlay on the target tensor

    return target_tensor



In [19]:

root_dir = r'Data\Final Testing Dataset\Final Testing Dataset'
dataset = datasets.ImageFolder(root=root_dir, transform=transform)
val_size = int(0.2 * len(dataset))
test_size = int(0.2 * len(dataset))
train_size = len(dataset) - val_size - test_size

# Split the dataset into training, validation, and test sets
train_dataset, val_dataset, test_dataset = torch.utils.data.random_split(dataset, [train_size, val_size, test_size])
train_dataloader = torch.utils.data.DataLoader(train_dataset, batch_size=32, shuffle=True, num_workers=4)
val_dataloader = torch.utils.data.DataLoader(val_dataset, batch_size=32, shuffle=False, num_workers=4)
test_dataloader = torch.utils.data.DataLoader(test_dataset, batch_size=32, shuffle=False, num_workers=4)

In [20]:
correct = 0
total = 0
total_loss=0
with torch.no_grad():
    for (images, labels) in test_dataloader:
        images, labels = images.to(device), labels.to(device)
        outputs = forward(images)
        val_loss = nn.CrossEntropyLoss()(outputs, labels)
        total_loss += loss.item()
        _, predicted = torch.max(outputs.data, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()
    print("Accuracy: ", correct/total, "\tLoss: ", total_loss/len(test_dataloader))

tensor([1, 0, 0, 1, 0, 0, 1, 2, 0, 1, 1, 0, 1, 0, 0, 2, 1, 2, 0, 2, 0, 2, 1, 2,
        0, 2, 1, 0, 0, 0, 0, 2], device='cuda:0')
tensor([2, 0, 1, 0, 2, 0, 2, 1, 1, 2, 0, 0, 2, 2, 1, 0, 2, 1, 2, 0, 2, 0, 2, 0,
        1, 1, 1, 2, 1, 0, 1, 2], device='cuda:0')
tensor([2, 1, 1, 2, 1, 2, 2, 2, 2, 1, 2, 2, 1, 2, 0, 0, 2, 2, 0, 0, 2, 2, 1, 1,
        1, 0, 1, 2, 0, 2, 0, 0], device='cuda:0')
tensor([1, 0, 2, 2, 0, 0, 0, 0, 1, 2, 0, 2, 0, 2, 1, 1, 2, 1, 1, 1, 0, 0, 1, 2,
        1, 1, 0, 1, 2, 1, 2, 1], device='cuda:0')
tensor([2, 1, 1, 1, 1, 0, 1, 1, 2, 0, 1, 2, 1, 2, 2, 0, 1, 2, 2, 1, 0, 1, 2, 2,
        1, 1, 0, 2, 0, 1, 2, 2], device='cuda:0')
tensor([2, 2, 2, 0, 0, 2, 0, 2, 1, 2, 1, 1, 1, 2, 1, 0, 0, 2, 1, 2, 1, 1, 2, 1,
        2, 0, 2, 0, 2, 1, 1, 2], device='cuda:0')
tensor([1, 2, 1, 2, 1, 2, 2, 0, 1, 0, 0, 2, 2, 0, 1, 0, 1, 2, 2, 0, 0, 0, 0, 1,
        2, 0, 2, 2, 1, 2, 0, 1], device='cuda:0')
tensor([1, 1, 0, 1, 2, 1, 0, 1, 0, 2, 2, 2, 0, 1, 2, 1, 1, 2, 2, 1, 1, 1, 1, 2,
        0,

In [21]:
predicted

tensor([6, 0, 3, 0, 4, 3, 7, 5, 0, 7, 5, 5, 0, 2, 7, 4], device='cuda:0')

In [22]:
labels

tensor([ 8,  0,  5,  0,  7,  6, 11,  8,  0, 11,  8,  2,  0,  7, 10,  7],
       device='cuda:0')

: 