# Pytorch Implementation

In [2]:
import torch
import torch.nn as nn
from torchvision import datasets, transforms

In [3]:
from modules.CNN import CNN

In [4]:
from tqdm import tqdm

## Data

In [5]:
data_1_4 = fr"\Data\1_4"
data_5_8 = fr"\Data\5_8"
data_9_12= fr"\Data\9_12"

data_1_12= fr"Data\Final Testing Dataset\Final Testing Dataset"

In [6]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [7]:
transform = transforms.Compose([
    transforms.ToTensor(),
    # transforms.Resize((40, 40)).to(device),
    transforms.RandomRotation(45).to(device)
])

## Model

### 1_12

In [7]:
root_dir = r'Data\Final Testing Dataset\Final Testing Dataset'
dataset = datasets.ImageFolder(root=root_dir, transform=transform)

In [8]:
model = CNN(12).to(device)

In [9]:
val_size = int(0.2 * len(dataset))
test_size = int(0.2 * len(dataset))
train_size = len(dataset) - val_size - test_size

# Split the dataset into training, validation, and test sets
train_dataset, val_dataset, test_dataset = torch.utils.data.random_split(dataset, [train_size, val_size, test_size])

In [10]:

train_dataloader = torch.utils.data.DataLoader(train_dataset, batch_size=32, shuffle=True, num_workers=4)
val_dataloader = torch.utils.data.DataLoader(val_dataset, batch_size=32, shuffle=False, num_workers=4)
test_dataloader = torch.utils.data.DataLoader(test_dataset, batch_size=32, shuffle=False, num_workers=4)

In [11]:
loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

In [12]:
length = len(train_dataloader)
for epoch in range(50):  # loop over the dataset multiple times
    print("Epoch: ", epoch+1)
    pbar = tqdm(train_dataloader)
    running_loss = 0.0
    correct = 0
    total_seen = 0
    for inputs, labels in pbar:
        inputs, labels = inputs.to(device), labels.to(device)
        # total_seen += labels.size(0)

        optimizer.zero_grad()

        outputs = model(inputs)
        loss = loss_fn(outputs, labels)
        loss.backward()
        optimizer.step()

        running_loss += loss.item()
        # print(labels.shape, outputs)
        _, predicted = torch.max(outputs.data, 1)
        # print(predicted.size(0))
        total_seen+=predicted.size(0)
        correct += (predicted == labels).sum().item()
        pbar.set_description(f"Running Accuracy: { correct/total_seen},  \t Batch Loss:  {loss}")

    # Testing the network
    correct = 0
    total = 0
    total_loss=0
    with torch.no_grad():
        for (images, labels) in val_dataloader:
            images, labels = images.to(device), labels.to(device)
            outputs = model(images)
            val_loss = nn.CrossEntropyLoss()(outputs, labels)
            total_loss += loss.item()
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
        print("Accuracy: ", correct/total, "\tLoss: ", total_loss/len(val_dataloader))

Epoch:  1


Running Accuracy: 0.23271313524021106,  	 Batch Loss:  1.794292688369751: 100%|██████████| 113/113 [00:09<00:00, 11.94it/s]


Accuracy:  0.3416666666666667 	Loss:  1.794292688369751
Epoch:  2


Running Accuracy: 0.35018050541516244,  	 Batch Loss:  1.9072198867797852: 100%|██████████| 113/113 [00:07<00:00, 15.30it/s]


Accuracy:  0.38333333333333336 	Loss:  1.9072198867797852
Epoch:  3


Running Accuracy: 0.40238822549291864,  	 Batch Loss:  1.4016845226287842: 100%|██████████| 113/113 [00:07<00:00, 15.42it/s]


Accuracy:  0.3625 	Loss:  1.4016845226287842
Epoch:  4


Running Accuracy: 0.47875590113857264,  	 Batch Loss:  1.2770018577575684: 100%|██████████| 113/113 [00:07<00:00, 15.43it/s]


Accuracy:  0.49083333333333334 	Loss:  1.2770018577575684
Epoch:  5


Running Accuracy: 0.542349347403499,  	 Batch Loss:  1.1132389307022095: 100%|██████████| 113/113 [00:07<00:00, 15.32it/s] 


Accuracy:  0.5241666666666667 	Loss:  1.1132389307022095
Epoch:  6


Running Accuracy: 0.5417939461260761,  	 Batch Loss:  1.3848941326141357: 100%|██████████| 113/113 [00:07<00:00, 15.46it/s]


Accuracy:  0.565 	Loss:  1.3848941326141357
Epoch:  7


Running Accuracy: 0.6064981949458483,  	 Batch Loss:  0.9637665748596191: 100%|██████████| 113/113 [00:07<00:00, 15.60it/s]


Accuracy:  0.5266666666666666 	Loss:  0.9637665748596191
Epoch:  8


Running Accuracy: 0.5923354623715634,  	 Batch Loss:  0.8349295258522034: 100%|██████████| 113/113 [00:07<00:00, 15.50it/s]


Accuracy:  0.5808333333333333 	Loss:  0.8349295258522034
Epoch:  9


Running Accuracy: 0.6264926409330741,  	 Batch Loss:  0.8002030849456787: 100%|██████████| 113/113 [00:07<00:00, 15.46it/s]


Accuracy:  0.6133333333333333 	Loss:  0.8002030849456787
Epoch:  10


Running Accuracy: 0.6273257428492085,  	 Batch Loss:  0.9647778868675232: 100%|██████████| 113/113 [00:07<00:00, 15.71it/s]


Accuracy:  0.6008333333333333 	Loss:  0.9647778868675232
Epoch:  11


Running Accuracy: 0.6525965009719522,  	 Batch Loss:  0.6207951903343201: 100%|██████████| 113/113 [00:07<00:00, 15.69it/s]


Accuracy:  0.5558333333333333 	Loss:  0.6207951903343201
Epoch:  12


Running Accuracy: 0.6600944182171619,  	 Batch Loss:  0.8127725720405579: 100%|██████████| 113/113 [00:07<00:00, 15.72it/s]


Accuracy:  0.63 	Loss:  0.8127725720405579
Epoch:  13


Running Accuracy: 0.6737017495140238,  	 Batch Loss:  0.6031840443611145: 100%|██████████| 113/113 [00:07<00:00, 15.69it/s]


Accuracy:  0.6375 	Loss:  0.6031840443611145
Epoch:  14


Running Accuracy: 0.6692585392946404,  	 Batch Loss:  1.0582355260849: 100%|██████████| 113/113 [00:07<00:00, 15.98it/s]  


Accuracy:  0.6508333333333334 	Loss:  1.0582355260849
Epoch:  15


Running Accuracy: 0.6842543737850597,  	 Batch Loss:  0.6255459785461426: 100%|██████████| 113/113 [00:07<00:00, 15.67it/s]


Accuracy:  0.6116666666666667 	Loss:  0.6255459785461426
Epoch:  16


Running Accuracy: 0.690919189114135,  	 Batch Loss:  0.8226152062416077: 100%|██████████| 113/113 [00:07<00:00, 15.70it/s] 


Accuracy:  0.6391666666666667 	Loss:  0.8226152062416077
Epoch:  17


Running Accuracy: 0.7023049153013052,  	 Batch Loss:  0.4828568696975708: 100%|██████████| 113/113 [00:07<00:00, 15.64it/s]


Accuracy:  0.6316666666666667 	Loss:  0.4828568696975708
Epoch:  18


Running Accuracy: 0.68619827825604,  	 Batch Loss:  0.7510051727294922: 100%|██████████| 113/113 [00:07<00:00, 15.82it/s]  


Accuracy:  0.6408333333333334 	Loss:  0.7510051727294922
Epoch:  19


Running Accuracy: 0.7081366287142461,  	 Batch Loss:  0.771457850933075: 100%|██████████| 113/113 [00:07<00:00, 15.53it/s] 


Accuracy:  0.6675 	Loss:  0.771457850933075
Epoch:  20


Running Accuracy: 0.7136906414884754,  	 Batch Loss:  0.5629341006278992: 100%|██████████| 113/113 [00:07<00:00, 15.60it/s]


Accuracy:  0.6775 	Loss:  0.5629341006278992
Epoch:  21


Running Accuracy: 0.7136906414884754,  	 Batch Loss:  0.5644461512565613: 100%|██████████| 113/113 [00:07<00:00, 15.84it/s] 


Accuracy:  0.6658333333333334 	Loss:  0.5644461512565613
Epoch:  22


Running Accuracy: 0.7250763676756456,  	 Batch Loss:  0.763196587562561: 100%|██████████| 113/113 [00:07<00:00, 15.63it/s] 


Accuracy:  0.6466666666666666 	Loss:  0.763196587562561
Epoch:  23


Running Accuracy: 0.712857539572341,  	 Batch Loss:  0.912513256072998: 100%|██████████| 113/113 [00:07<00:00, 15.83it/s]  


Accuracy:  0.6125 	Loss:  0.912513256072998
Epoch:  24


Running Accuracy: 0.7200777561788392,  	 Batch Loss:  0.9176917672157288: 100%|██████████| 113/113 [00:07<00:00, 15.90it/s] 


Accuracy:  0.6816666666666666 	Loss:  0.9176917672157288
Epoch:  25


Running Accuracy: 0.7186892529852819,  	 Batch Loss:  0.7490048408508301: 100%|██████████| 113/113 [00:07<00:00, 15.67it/s]


Accuracy:  0.675 	Loss:  0.7490048408508301
Epoch:  26


Running Accuracy: 0.7356289919466815,  	 Batch Loss:  0.827994167804718: 100%|██████████| 113/113 [00:07<00:00, 15.50it/s] 


Accuracy:  0.6791666666666667 	Loss:  0.827994167804718
Epoch:  27


Running Accuracy: 0.7322965842821438,  	 Batch Loss:  0.7915317416191101: 100%|██████████| 113/113 [00:07<00:00, 15.61it/s]


Accuracy:  0.6866666666666666 	Loss:  0.7915317416191101
Epoch:  28


Running Accuracy: 0.7409053040821993,  	 Batch Loss:  0.820891797542572: 100%|██████████| 113/113 [00:07<00:00, 15.52it/s] 


Accuracy:  0.6475 	Loss:  0.820891797542572
Epoch:  29


Running Accuracy: 0.7375728964176618,  	 Batch Loss:  0.5408692359924316: 100%|██████████| 113/113 [00:07<00:00, 15.48it/s]


Accuracy:  0.6908333333333333 	Loss:  0.5408692359924316
Epoch:  30


Running Accuracy: 0.7386836989725076,  	 Batch Loss:  0.7079589366912842: 100%|██████████| 113/113 [00:07<00:00, 15.32it/s]


Accuracy:  0.6616666666666666 	Loss:  0.7079589366912842
Epoch:  31


Running Accuracy: 0.7439600111080256,  	 Batch Loss:  0.8444207310676575: 100%|██████████| 113/113 [00:07<00:00, 15.50it/s]


Accuracy:  0.6916666666666667 	Loss:  0.8444207310676575
Epoch:  32


Running Accuracy: 0.744237711746737,  	 Batch Loss:  0.9034938812255859: 100%|██████████| 113/113 [00:07<00:00, 15.41it/s] 


Accuracy:  0.7033333333333334 	Loss:  0.9034938812255859
Epoch:  33


Running Accuracy: 0.7517356289919467,  	 Batch Loss:  0.5437516570091248: 100%|██████████| 113/113 [00:07<00:00, 15.75it/s] 


Accuracy:  0.7 	Loss:  0.5437516570091248
Epoch:  34


Running Accuracy: 0.7503471257983894,  	 Batch Loss:  0.4798906445503235: 100%|██████████| 113/113 [00:07<00:00, 15.24it/s]


Accuracy:  0.7083333333333334 	Loss:  0.4798906445503235
Epoch:  35


Running Accuracy: 0.7642321577339628,  	 Batch Loss:  0.5668668746948242: 100%|██████████| 113/113 [00:07<00:00, 15.61it/s]


Accuracy:  0.6891666666666667 	Loss:  0.5668668746948242
Epoch:  36


Running Accuracy: 0.7572896417661761,  	 Batch Loss:  0.762477457523346: 100%|██████████| 113/113 [00:07<00:00, 15.73it/s] 


Accuracy:  0.7066666666666667 	Loss:  0.762477457523346
Epoch:  37


Running Accuracy: 0.7572896417661761,  	 Batch Loss:  0.8642481565475464: 100%|██████████| 113/113 [00:07<00:00, 15.06it/s] 


Accuracy:  0.66 	Loss:  0.8642481565475464
Epoch:  38


Running Accuracy: 0.7497917245209664,  	 Batch Loss:  0.48060479760169983: 100%|██████████| 113/113 [00:07<00:00, 15.79it/s]


Accuracy:  0.6766666666666666 	Loss:  0.48060479760169983
Epoch:  39


Running Accuracy: 0.7575673424048875,  	 Batch Loss:  0.25332507491111755: 100%|██████████| 113/113 [00:07<00:00, 15.82it/s]


Accuracy:  0.6916666666666667 	Loss:  0.25332507491111755
Epoch:  40


Running Accuracy: 0.7745070813662871,  	 Batch Loss:  0.5016249418258667: 100%|██████████| 113/113 [00:07<00:00, 15.96it/s]


Accuracy:  0.6775 	Loss:  0.5016249418258667
Epoch:  41


Running Accuracy: 0.7692307692307693,  	 Batch Loss:  0.5920419692993164: 100%|██████████| 113/113 [00:07<00:00, 15.92it/s]


Accuracy:  0.6933333333333334 	Loss:  0.5920419692993164
Epoch:  42


Running Accuracy: 0.7711746737017495,  	 Batch Loss:  0.558148205280304: 100%|██████████| 113/113 [00:07<00:00, 15.75it/s] 


Accuracy:  0.6275 	Loss:  0.558148205280304
Epoch:  43


Running Accuracy: 0.7706192724243266,  	 Batch Loss:  0.4317047894001007: 100%|██████████| 113/113 [00:07<00:00, 15.78it/s]


Accuracy:  0.6875 	Loss:  0.4317047894001007
Epoch:  44


Running Accuracy: 0.7908914190502638,  	 Batch Loss:  0.532651424407959: 100%|██████████| 113/113 [00:07<00:00, 15.89it/s] 


Accuracy:  0.7333333333333333 	Loss:  0.532651424407959
Epoch:  45


Running Accuracy: 0.7925576228825326,  	 Batch Loss:  0.5363736152648926: 100%|██████████| 113/113 [00:07<00:00, 15.79it/s]


Accuracy:  0.6875 	Loss:  0.5363736152648926
Epoch:  46


Running Accuracy: 0.7717300749791725,  	 Batch Loss:  0.48017334938049316: 100%|██████████| 113/113 [00:07<00:00, 15.96it/s]


Accuracy:  0.685 	Loss:  0.48017334938049316
Epoch:  47


Running Accuracy: 0.7947792279922243,  	 Batch Loss:  0.5164469480514526: 100%|██████████| 113/113 [00:07<00:00, 16.10it/s]


Accuracy:  0.7175 	Loss:  0.5164469480514526
Epoch:  48


Running Accuracy: 0.7831158011663427,  	 Batch Loss:  0.5234214067459106: 100%|██████████| 113/113 [00:07<00:00, 15.64it/s] 


Accuracy:  0.7075 	Loss:  0.5234214067459106
Epoch:  49


Running Accuracy: 0.7792279922243821,  	 Batch Loss:  0.7390831112861633: 100%|██████████| 113/113 [00:07<00:00, 15.76it/s]


Accuracy:  0.7083333333333334 	Loss:  0.7390831112861633
Epoch:  50


Running Accuracy: 0.7817272979727853,  	 Batch Loss:  0.6144466996192932: 100%|██████████| 113/113 [00:07<00:00, 15.90it/s]


Accuracy:  0.7075 	Loss:  0.6144466996192932


In [13]:
torch.save(model, '1_12_bats.pth')

### 1_4

In [8]:
root_dir = r'Data\Top_level\1_4'
dataset = datasets.ImageFolder(root=root_dir, transform=transform)
model_1_4 = CNN(4).to(device)
val_size = int(0.2 * len(dataset))
test_size = int(0.2 * len(dataset))
train_size = len(dataset) - val_size - test_size

# Split the dataset into training, validation, and test sets
train_dataset, val_dataset, test_dataset = torch.utils.data.random_split(dataset, [train_size, val_size, test_size])
train_dataloader = torch.utils.data.DataLoader(train_dataset, batch_size=32, shuffle=True, num_workers=4)
val_dataloader = torch.utils.data.DataLoader(val_dataset, batch_size=32, shuffle=False, num_workers=4)
test_dataloader = torch.utils.data.DataLoader(test_dataset, batch_size=32, shuffle=False, num_workers=4)
loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model_1_4.parameters(), lr=0.0001)

length = len(train_dataloader)
for epoch in range(25):  # loop over the dataset multiple times
    print("Epoch: ", epoch+1)
    pbar = tqdm(train_dataloader)
    running_loss = 0.0
    correct = 0
    total_seen = 0
    for inputs, labels in pbar:
        inputs, labels = inputs.to(device), labels.to(device)
        # total_seen += labels.size(0)

        optimizer.zero_grad()

        outputs = model_1_4(inputs)
        loss = loss_fn(outputs, labels)
        loss.backward()
        optimizer.step()

        running_loss += loss.item()
        # print(labels.shape, outputs)
        _, predicted = torch.max(outputs.data, 1)
        # print(predicted.size(0))
        total_seen+=predicted.size(0)
        correct += (predicted == labels).sum().item()
        pbar.set_description(f"Running Accuracy: { correct/total_seen},  \t Batch Loss:  {loss}")

    # Testing the network
    correct = 0
    total = 0
    total_loss=0
    with torch.no_grad():
        for (images, labels) in val_dataloader:
            images, labels = images.to(device), labels.to(device)
            outputs = model_1_4(images)
            val_loss = nn.CrossEntropyLoss()(outputs, labels)
            total_loss += loss.item()
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
        print("Accuracy: ", correct/total, "\tLoss: ", total_loss/len(val_dataloader))

Epoch:  1


  return F.softmax(self.model(x))
Running Accuracy: 0.2622814321398834,  	 Batch Loss:  1.32346773147583: 100%|██████████| 38/38 [00:07<00:00,  5.31it/s]   


Accuracy:  0.4675 	Loss:  1.32346773147583
Epoch:  2


Running Accuracy: 0.4163197335553705,  	 Batch Loss:  1.2542829513549805: 100%|██████████| 38/38 [00:06<00:00,  5.46it/s] 


Accuracy:  0.51 	Loss:  1.2542829513549805
Epoch:  3


Running Accuracy: 0.5420482930890924,  	 Batch Loss:  1.0510647296905518: 100%|██████████| 38/38 [00:06<00:00,  5.46it/s]


Accuracy:  0.5275 	Loss:  1.0510647296905518
Epoch:  4


Running Accuracy: 0.5845129059117402,  	 Batch Loss:  1.2312097549438477: 100%|██████████| 38/38 [00:07<00:00,  5.39it/s]


Accuracy:  0.57 	Loss:  1.2312097549438477
Epoch:  5


Running Accuracy: 0.6194837635303914,  	 Batch Loss:  1.050695538520813: 100%|██████████| 38/38 [00:07<00:00,  5.40it/s] 


Accuracy:  0.6225 	Loss:  1.050695538520813
Epoch:  6


Running Accuracy: 0.6161532056619484,  	 Batch Loss:  1.101586937904358: 100%|██████████| 38/38 [00:07<00:00,  5.35it/s] 


Accuracy:  0.6575 	Loss:  1.101586937904358
Epoch:  7


Running Accuracy: 0.6502914238134888,  	 Batch Loss:  1.0145056247711182: 100%|██████████| 38/38 [00:06<00:00,  5.58it/s]


Accuracy:  0.655 	Loss:  1.0145056247711182
Epoch:  8


Running Accuracy: 0.6477935054121565,  	 Batch Loss:  1.040734887123108: 100%|██████████| 38/38 [00:06<00:00,  5.46it/s] 


Accuracy:  0.6 	Loss:  1.040734887123108
Epoch:  9


Running Accuracy: 0.6536219816819318,  	 Batch Loss:  0.9781327843666077: 100%|██████████| 38/38 [00:06<00:00,  5.48it/s]


Accuracy:  0.65 	Loss:  0.9781327843666077
Epoch:  10


Running Accuracy: 0.6769358867610324,  	 Batch Loss:  0.9463711977005005: 100%|██████████| 38/38 [00:06<00:00,  5.59it/s]


Accuracy:  0.7075 	Loss:  0.9463711977005005
Epoch:  11


Running Accuracy: 0.6960865945045795,  	 Batch Loss:  1.1705741882324219: 100%|██████████| 38/38 [00:06<00:00,  5.55it/s]


Accuracy:  0.68 	Loss:  1.1705741882324219
Epoch:  12


Running Accuracy: 0.6835970024979184,  	 Batch Loss:  1.0214309692382812: 100%|██████████| 38/38 [00:06<00:00,  5.47it/s]


Accuracy:  0.675 	Loss:  1.0214309692382812
Epoch:  13


Running Accuracy: 0.704412989175687,  	 Batch Loss:  0.9956305027008057: 100%|██████████| 38/38 [00:07<00:00,  5.42it/s] 


Accuracy:  0.705 	Loss:  0.9956305027008057
Epoch:  14


Running Accuracy: 0.7160699417152373,  	 Batch Loss:  1.0701725482940674: 100%|██████████| 38/38 [00:06<00:00,  5.49it/s]


Accuracy:  0.7225 	Loss:  1.0701725482940674
Epoch:  15


Running Accuracy: 0.7160699417152373,  	 Batch Loss:  0.9718973636627197: 100%|██████████| 38/38 [00:06<00:00,  5.50it/s]


Accuracy:  0.7525 	Loss:  0.9718973636627197
Epoch:  16


Running Accuracy: 0.7385512073272273,  	 Batch Loss:  1.0330451726913452: 100%|██████████| 38/38 [00:06<00:00,  5.57it/s]


Accuracy:  0.755 	Loss:  1.0330451726913452
Epoch:  17


Running Accuracy: 0.7435470441298918,  	 Batch Loss:  0.9165744781494141: 100%|██████████| 38/38 [00:06<00:00,  5.55it/s]


Accuracy:  0.76 	Loss:  0.9165744781494141
Epoch:  18


Running Accuracy: 0.7785179017485429,  	 Batch Loss:  0.9889556765556335: 100%|██████████| 38/38 [00:06<00:00,  5.62it/s]


Accuracy:  0.7725 	Loss:  0.9889556765556335
Epoch:  19


Running Accuracy: 0.7577019150707743,  	 Batch Loss:  0.8860249519348145: 100%|██████████| 38/38 [00:06<00:00,  5.57it/s]


Accuracy:  0.7425 	Loss:  0.8860249519348145
Epoch:  20


Running Accuracy: 0.7560366361365529,  	 Batch Loss:  1.0901321172714233: 100%|██████████| 38/38 [00:07<00:00,  5.39it/s]


Accuracy:  0.7975 	Loss:  1.0901321172714233
Epoch:  21


Running Accuracy: 0.7901748542880933,  	 Batch Loss:  0.9450153112411499: 100%|██████████| 38/38 [00:06<00:00,  5.52it/s]


Accuracy:  0.7725 	Loss:  0.9450153112411499
Epoch:  22


Running Accuracy: 0.7801831806827644,  	 Batch Loss:  0.831143856048584: 100%|██████████| 38/38 [00:06<00:00,  5.54it/s] 


Accuracy:  0.7975 	Loss:  0.831143856048584
Epoch:  23


Running Accuracy: 0.7843463780183181,  	 Batch Loss:  0.8808000683784485: 100%|██████████| 38/38 [00:06<00:00,  5.47it/s]


Accuracy:  0.7725 	Loss:  0.8808000683784485
Epoch:  24


Running Accuracy: 0.7860116569525396,  	 Batch Loss:  1.0473318099975586: 100%|██████████| 38/38 [00:06<00:00,  5.65it/s]


Accuracy:  0.7875 	Loss:  1.0473318099975586
Epoch:  25


Running Accuracy: 0.7552039966694422,  	 Batch Loss:  1.0470463037490845: 100%|██████████| 38/38 [00:06<00:00,  5.48it/s]


Accuracy:  0.7675 	Loss:  1.0470463037490845


### 5_8

In [9]:
root_dir = r'Data\Top_level\5_8'
dataset = datasets.ImageFolder(root=root_dir, transform=transform)
model_5_8 = CNN(4).to(device)
val_size = int(0.2 * len(dataset))
test_size = int(0.2 * len(dataset))
train_size = len(dataset) - val_size - test_size

# Split the dataset into training, validation, and test sets
train_dataset, val_dataset, test_dataset = torch.utils.data.random_split(dataset, [train_size, val_size, test_size])
train_dataloader = torch.utils.data.DataLoader(train_dataset, batch_size=32, shuffle=True, num_workers=4)
val_dataloader = torch.utils.data.DataLoader(val_dataset, batch_size=32, shuffle=False, num_workers=4)
test_dataloader = torch.utils.data.DataLoader(test_dataset, batch_size=32, shuffle=False, num_workers=4)
loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model_5_8.parameters(), lr=0.0001)

length = len(train_dataloader)
for epoch in range(25):  # loop over the dataset multiple times
    print("Epoch: ", epoch+1)
    pbar = tqdm(train_dataloader)
    running_loss = 0.0
    correct = 0
    total_seen = 0
    for inputs, labels in pbar:
        inputs, labels = inputs.to(device), labels.to(device)
        # total_seen += labels.size(0)

        optimizer.zero_grad()

        outputs = model_5_8(inputs)
        loss = loss_fn(outputs, labels)
        loss.backward()
        optimizer.step()

        running_loss += loss.item()
        # print(labels.shape, outputs)
        _, predicted = torch.max(outputs.data, 1)
        # print(predicted.size(0))
        total_seen+=predicted.size(0)
        correct += (predicted == labels).sum().item()
        pbar.set_description(f"Running Accuracy: { correct/total_seen},  \t Batch Loss:  {loss}")

    # Testing the network
    correct = 0
    total = 0
    total_loss=0
    with torch.no_grad():
        for (images, labels) in val_dataloader:
            images, labels = images.to(device), labels.to(device)
            outputs = model_5_8(images)
            val_loss = nn.CrossEntropyLoss()(outputs, labels)
            total_loss += loss.item()
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
        print("Accuracy: ", correct/total, "\tLoss: ", total_loss/len(val_dataloader))

Epoch:  1


Running Accuracy: 0.2675,  	 Batch Loss:  1.3848109245300293: 100%|██████████| 38/38 [00:06<00:00,  5.44it/s]             


Accuracy:  0.2475 	Loss:  1.3848109245300293
Epoch:  2


Running Accuracy: 0.2725,  	 Batch Loss:  1.3761368989944458: 100%|██████████| 38/38 [00:06<00:00,  5.49it/s]             


Accuracy:  0.395 	Loss:  1.3761368989944458
Epoch:  3


Running Accuracy: 0.3425,  	 Batch Loss:  1.3459969758987427: 100%|██████████| 38/38 [00:06<00:00,  5.47it/s]             


Accuracy:  0.4075 	Loss:  1.3459969758987427
Epoch:  4


Running Accuracy: 0.49083333333333334,  	 Batch Loss:  1.2136622667312622: 100%|██████████| 38/38 [00:06<00:00,  5.48it/s]


Accuracy:  0.5325 	Loss:  1.2136622667312622
Epoch:  5


Running Accuracy: 0.5875,  	 Batch Loss:  1.1265952587127686: 100%|██████████| 38/38 [00:07<00:00,  5.42it/s]            


Accuracy:  0.5875 	Loss:  1.1265952587127686
Epoch:  6


Running Accuracy: 0.6325,  	 Batch Loss:  1.111323356628418: 100%|██████████| 38/38 [00:06<00:00,  5.58it/s]             


Accuracy:  0.615 	Loss:  1.111323356628418
Epoch:  7


Running Accuracy: 0.665,  	 Batch Loss:  1.0638412237167358: 100%|██████████| 38/38 [00:06<00:00,  5.58it/s]             


Accuracy:  0.6525 	Loss:  1.0638412237167358
Epoch:  8


Running Accuracy: 0.6983333333333334,  	 Batch Loss:  1.0877630710601807: 100%|██████████| 38/38 [00:06<00:00,  5.53it/s]


Accuracy:  0.6125 	Loss:  1.0877630710601807
Epoch:  9


Running Accuracy: 0.6725,  	 Batch Loss:  1.0262426137924194: 100%|██████████| 38/38 [00:06<00:00,  5.49it/s]            


Accuracy:  0.62 	Loss:  1.0262426137924194
Epoch:  10


Running Accuracy: 0.6941666666666667,  	 Batch Loss:  0.9742281436920166: 100%|██████████| 38/38 [00:06<00:00,  5.56it/s]


Accuracy:  0.655 	Loss:  0.9742281436920166
Epoch:  11


Running Accuracy: 0.6825,  	 Batch Loss:  0.9959296584129333: 100%|██████████| 38/38 [00:06<00:00,  5.47it/s]            


Accuracy:  0.6425 	Loss:  0.9959296584129333
Epoch:  12


Running Accuracy: 0.7066666666666667,  	 Batch Loss:  1.0134239196777344: 100%|██████████| 38/38 [00:06<00:00,  5.56it/s]


Accuracy:  0.68 	Loss:  1.0134239196777344
Epoch:  13


Running Accuracy: 0.6925,  	 Batch Loss:  0.9824653267860413: 100%|██████████| 38/38 [00:06<00:00,  5.60it/s]            


Accuracy:  0.6275 	Loss:  0.9824653267860413
Epoch:  14


Running Accuracy: 0.7008333333333333,  	 Batch Loss:  1.0830333232879639: 100%|██████████| 38/38 [00:06<00:00,  5.56it/s]


Accuracy:  0.6825 	Loss:  1.0830333232879639
Epoch:  15


Running Accuracy: 0.735,  	 Batch Loss:  0.9383434057235718: 100%|██████████| 38/38 [00:06<00:00,  5.53it/s]             


Accuracy:  0.665 	Loss:  0.9383434057235718
Epoch:  16


Running Accuracy: 0.7041666666666667,  	 Batch Loss:  1.043535590171814: 100%|██████████| 38/38 [00:06<00:00,  5.46it/s] 


Accuracy:  0.68 	Loss:  1.043535590171814
Epoch:  17


Running Accuracy: 0.7141666666666666,  	 Batch Loss:  0.9736774563789368: 100%|██████████| 38/38 [00:06<00:00,  5.53it/s]


Accuracy:  0.6825 	Loss:  0.9736774563789368
Epoch:  18


Running Accuracy: 0.7408333333333333,  	 Batch Loss:  0.9369277358055115: 100%|██████████| 38/38 [00:06<00:00,  5.50it/s]


Accuracy:  0.695 	Loss:  0.9369277358055115
Epoch:  19


Running Accuracy: 0.7416666666666667,  	 Batch Loss:  0.9741794466972351: 100%|██████████| 38/38 [00:06<00:00,  5.56it/s]


Accuracy:  0.7 	Loss:  0.9741794466972351
Epoch:  20


Running Accuracy: 0.7191666666666666,  	 Batch Loss:  0.9404433965682983: 100%|██████████| 38/38 [00:06<00:00,  5.57it/s]


Accuracy:  0.7275 	Loss:  0.9404433965682983
Epoch:  21


Running Accuracy: 0.7391666666666666,  	 Batch Loss:  1.112856388092041: 100%|██████████| 38/38 [00:06<00:00,  5.61it/s] 


Accuracy:  0.6625 	Loss:  1.112856388092041
Epoch:  22


Running Accuracy: 0.7391666666666666,  	 Batch Loss:  0.9179673194885254: 100%|██████████| 38/38 [00:06<00:00,  5.48it/s]


Accuracy:  0.7025 	Loss:  0.9179673194885254
Epoch:  23


Running Accuracy: 0.7283333333333334,  	 Batch Loss:  1.1694988012313843: 100%|██████████| 38/38 [00:06<00:00,  5.58it/s]


Accuracy:  0.6975 	Loss:  1.1694988012313843
Epoch:  24


Running Accuracy: 0.7516666666666667,  	 Batch Loss:  0.995733916759491: 100%|██████████| 38/38 [00:06<00:00,  5.46it/s] 


Accuracy:  0.71 	Loss:  0.995733916759491
Epoch:  25


Running Accuracy: 0.7483333333333333,  	 Batch Loss:  0.9428406953811646: 100%|██████████| 38/38 [00:06<00:00,  5.46it/s]


Accuracy:  0.7325 	Loss:  0.9428406953811646


### 9_12

In [10]:
root_dir = r'Data\Top_level\9_12'
dataset = datasets.ImageFolder(root=root_dir, transform=transform)
model_9_12 = CNN(4).to(device)
val_size = int(0.2 * len(dataset))
test_size = int(0.2 * len(dataset))
train_size = len(dataset) - val_size - test_size

# Split the dataset into training, validation, and test sets
train_dataset, val_dataset, test_dataset = torch.utils.data.random_split(dataset, [train_size, val_size, test_size])
train_dataloader = torch.utils.data.DataLoader(train_dataset, batch_size=32, shuffle=True, num_workers=4)
val_dataloader = torch.utils.data.DataLoader(val_dataset, batch_size=32, shuffle=False, num_workers=4)
test_dataloader = torch.utils.data.DataLoader(test_dataset, batch_size=32, shuffle=False, num_workers=4)
loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model_9_12.parameters(), lr=0.0001)

length = len(train_dataloader)
for epoch in range(50):  # loop over the dataset multiple times
    print("Epoch: ", epoch+1)
    pbar = tqdm(train_dataloader)
    running_loss = 0.0
    correct = 0
    total_seen = 0
    for inputs, labels in pbar:
        inputs, labels = inputs.to(device), labels.to(device)
        # total_seen += labels.size(0)

        optimizer.zero_grad()

        outputs = model_9_12(inputs)
        loss = loss_fn(outputs, labels)
        loss.backward()
        optimizer.step()

        running_loss += loss.item()
        # print(labels.shape, outputs)
        _, predicted = torch.max(outputs.data, 1)
        # print(predicted.size(0))
        total_seen+=predicted.size(0)
        correct += (predicted == labels).sum().item()
        pbar.set_description(f"Running Accuracy: { correct/total_seen},  \t Batch Loss:  {loss}")

    # Testing the network
    correct = 0
    total = 0
    total_loss=0
    with torch.no_grad():
        for (images, labels) in val_dataloader:
            images, labels = images.to(device), labels.to(device)
            outputs = model_9_12(images)
            val_loss = nn.CrossEntropyLoss()(outputs, labels)
            total_loss += loss.item()
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
        print("Accuracy: ", correct/total, "\tLoss: ", total_loss/len(val_dataloader))

Epoch:  1


Running Accuracy: 0.2325,  	 Batch Loss:  1.378129482269287: 100%|██████████| 38/38 [00:06<00:00,  5.44it/s]              


Accuracy:  0.24 	Loss:  1.378129482269287
Epoch:  2


Running Accuracy: 0.2575,  	 Batch Loss:  1.3771227598190308: 100%|██████████| 38/38 [00:06<00:00,  5.47it/s]             


Accuracy:  0.27 	Loss:  1.3771227598190308
Epoch:  3


Running Accuracy: 0.25916666666666666,  	 Batch Loss:  1.3824301958084106: 100%|██████████| 38/38 [00:07<00:00,  5.34it/s]


Accuracy:  0.27 	Loss:  1.3824301958084106
Epoch:  4


Running Accuracy: 0.25166666666666665,  	 Batch Loss:  1.3800312280654907: 100%|██████████| 38/38 [00:07<00:00,  5.36it/s]


Accuracy:  0.3475 	Loss:  1.3800312280654907
Epoch:  5


Running Accuracy: 0.2708333333333333,  	 Batch Loss:  1.3690038919448853: 100%|██████████| 38/38 [00:06<00:00,  5.49it/s] 


Accuracy:  0.2375 	Loss:  1.3690038919448853
Epoch:  6


Running Accuracy: 0.305,  	 Batch Loss:  1.364302158355713: 100%|██████████| 38/38 [00:06<00:00,  5.44it/s]               


Accuracy:  0.3725 	Loss:  1.364302158355713
Epoch:  7


Running Accuracy: 0.31166666666666665,  	 Batch Loss:  1.356785774230957: 100%|██████████| 38/38 [00:06<00:00,  5.50it/s] 


Accuracy:  0.3075 	Loss:  1.356785774230957
Epoch:  8


Running Accuracy: 0.3125,  	 Batch Loss:  1.3519880771636963: 100%|██████████| 38/38 [00:06<00:00,  5.44it/s]             


Accuracy:  0.32 	Loss:  1.3519880771636963
Epoch:  9


Running Accuracy: 0.37916666666666665,  	 Batch Loss:  1.3241548538208008: 100%|██████████| 38/38 [00:06<00:00,  5.56it/s]


Accuracy:  0.385 	Loss:  1.3241548538208008
Epoch:  10


Running Accuracy: 0.3883333333333333,  	 Batch Loss:  1.3135418891906738: 100%|██████████| 38/38 [00:06<00:00,  5.63it/s] 


Accuracy:  0.395 	Loss:  1.3135418891906738
Epoch:  11


Running Accuracy: 0.44333333333333336,  	 Batch Loss:  1.2794835567474365: 100%|██████████| 38/38 [00:06<00:00,  5.52it/s]


Accuracy:  0.4075 	Loss:  1.2794835567474365
Epoch:  12


Running Accuracy: 0.4216666666666667,  	 Batch Loss:  1.2487001419067383: 100%|██████████| 38/38 [00:06<00:00,  5.49it/s] 


Accuracy:  0.4175 	Loss:  1.2487001419067383
Epoch:  13


Running Accuracy: 0.47,  	 Batch Loss:  1.2395199537277222: 100%|██████████| 38/38 [00:06<00:00,  5.54it/s]               


Accuracy:  0.3875 	Loss:  1.2395199537277222
Epoch:  14


Running Accuracy: 0.5258333333333334,  	 Batch Loss:  1.2008579969406128: 100%|██████████| 38/38 [00:06<00:00,  5.56it/s] 


Accuracy:  0.4575 	Loss:  1.2008579969406128
Epoch:  15


Running Accuracy: 0.49083333333333334,  	 Batch Loss:  1.2824417352676392: 100%|██████████| 38/38 [00:06<00:00,  5.45it/s]


Accuracy:  0.4425 	Loss:  1.2824417352676392
Epoch:  16


Running Accuracy: 0.5325,  	 Batch Loss:  1.3528386354446411: 100%|██████████| 38/38 [00:07<00:00,  5.35it/s]            


Accuracy:  0.495 	Loss:  1.3528386354446411
Epoch:  17


Running Accuracy: 0.5325,  	 Batch Loss:  1.2128183841705322: 100%|██████████| 38/38 [00:06<00:00,  5.45it/s]            


Accuracy:  0.5225 	Loss:  1.2128183841705322
Epoch:  18


Running Accuracy: 0.5216666666666666,  	 Batch Loss:  1.2436866760253906: 100%|██████████| 38/38 [00:06<00:00,  5.47it/s]


Accuracy:  0.5075 	Loss:  1.2436866760253906
Epoch:  19


Running Accuracy: 0.545,  	 Batch Loss:  1.2631099224090576: 100%|██████████| 38/38 [00:06<00:00,  5.49it/s]             


Accuracy:  0.4625 	Loss:  1.2631099224090576
Epoch:  20


Running Accuracy: 0.5633333333333334,  	 Batch Loss:  1.0647560358047485: 100%|██████████| 38/38 [00:06<00:00,  5.50it/s]


Accuracy:  0.505 	Loss:  1.0647560358047485
Epoch:  21


Running Accuracy: 0.5533333333333333,  	 Batch Loss:  1.2174129486083984: 100%|██████████| 38/38 [00:06<00:00,  5.44it/s]


Accuracy:  0.5375 	Loss:  1.2174129486083984
Epoch:  22


  0%|          | 0/38 [00:03<?, ?it/s]


KeyboardInterrupt: 

### Top

In [11]:
root_dir = r'Data\Top_level'
dataset = datasets.ImageFolder(root=root_dir, transform=transform)
model_top = CNN(3).to(device)
val_size = int(0.2 * len(dataset))
test_size = int(0.2 * len(dataset))
train_size = len(dataset) - val_size - test_size

# Split the dataset into training, validation, and test sets
train_dataset, val_dataset, test_dataset = torch.utils.data.random_split(dataset, [train_size, val_size, test_size])
train_dataloader = torch.utils.data.DataLoader(train_dataset, batch_size=32, shuffle=True, num_workers=4)
val_dataloader = torch.utils.data.DataLoader(val_dataset, batch_size=32, shuffle=False, num_workers=4)
test_dataloader = torch.utils.data.DataLoader(test_dataset, batch_size=32, shuffle=False, num_workers=4)
loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model_top.parameters(), lr=0.0001)

length = len(train_dataloader)
for epoch in range(25):  # loop over the dataset multiple times
    print("Epoch: ", epoch+1)
    pbar = tqdm(train_dataloader)
    running_loss = 0.0
    correct = 0
    total_seen = 0
    for inputs, labels in pbar:
        inputs, labels = inputs.to(device), labels.to(device)
        # total_seen += labels.size(0)

        optimizer.zero_grad()

        outputs = model_top(inputs)
        loss = loss_fn(outputs, labels)
        loss.backward()
        optimizer.step()

        running_loss += loss.item()
        # print(labels.shape, outputs)
        _, predicted = torch.max(outputs.data, 1)
        # print(predicted.size(0))
        total_seen+=predicted.size(0)
        correct += (predicted == labels).sum().item()
        pbar.set_description(f"Running Accuracy: { correct/total_seen},  \t Batch Loss:  {loss}")

    # Testing the network
    correct = 0
    total = 0
    total_loss=0
    with torch.no_grad():
        for (images, labels) in val_dataloader:
            images, labels = images.to(device), labels.to(device)
            outputs = model_top(images)
            val_loss = nn.CrossEntropyLoss()(outputs, labels)
            total_loss += loss.item()
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
        print("Accuracy: ", correct/total, "\tLoss: ", total_loss/len(val_dataloader))

Epoch:  1


Running Accuracy: 0.43432379894473755,  	 Batch Loss:  1.0611133575439453: 100%|██████████| 113/113 [00:08<00:00, 13.66it/s]


Accuracy:  0.605 	Loss:  1.0611133575439453
Epoch:  2


Running Accuracy: 0.6392668703138017,  	 Batch Loss:  0.8022007942199707: 100%|██████████| 113/113 [00:07<00:00, 14.36it/s]


Accuracy:  0.71 	Loss:  0.8022007942199707
Epoch:  3


Running Accuracy: 0.7242432657595113,  	 Batch Loss:  0.7763104438781738: 100%|██████████| 113/113 [00:07<00:00, 14.66it/s]


Accuracy:  0.7466666666666667 	Loss:  0.7763104438781738
Epoch:  4


Running Accuracy: 0.766176062204943,  	 Batch Loss:  0.851377546787262: 100%|██████████| 113/113 [00:07<00:00, 14.67it/s]  


Accuracy:  0.7816666666666666 	Loss:  0.851377546787262
Epoch:  5


Running Accuracy: 0.7906137184115524,  	 Batch Loss:  0.6561484336853027: 100%|██████████| 113/113 [00:07<00:00, 15.07it/s]


Accuracy:  0.7966666666666666 	Loss:  0.6561484336853027
Epoch:  6


Running Accuracy: 0.7945015273535129,  	 Batch Loss:  0.7398011684417725: 100%|██████████| 113/113 [00:07<00:00, 15.25it/s]


Accuracy:  0.8116666666666666 	Loss:  0.7398011684417725
Epoch:  7


Running Accuracy: 0.8056095529019717,  	 Batch Loss:  0.8236948847770691: 100%|██████████| 113/113 [00:07<00:00, 15.46it/s]


Accuracy:  0.8133333333333334 	Loss:  0.8236948847770691
Epoch:  8


Running Accuracy: 0.8125520688697584,  	 Batch Loss:  0.6758787035942078: 100%|██████████| 113/113 [00:07<00:00, 15.53it/s]


Accuracy:  0.8233333333333334 	Loss:  0.6758787035942078
Epoch:  9


Running Accuracy: 0.8167175784504305,  	 Batch Loss:  0.7470481991767883: 100%|██████████| 113/113 [00:07<00:00, 15.06it/s]


Accuracy:  0.8366666666666667 	Loss:  0.7470481991767883
Epoch:  10


Running Accuracy: 0.8228269925020828,  	 Batch Loss:  0.6556544303894043: 100%|██████████| 113/113 [00:07<00:00, 15.64it/s]


Accuracy:  0.7366666666666667 	Loss:  0.6556544303894043
Epoch:  11


Running Accuracy: 0.8133851707858928,  	 Batch Loss:  0.8386995792388916: 100%|██████████| 113/113 [00:07<00:00, 15.40it/s]


Accuracy:  0.8025 	Loss:  0.8386995792388916
Epoch:  12


Running Accuracy: 0.8125520688697584,  	 Batch Loss:  0.6504620909690857: 100%|██████████| 113/113 [00:07<00:00, 15.85it/s]


Accuracy:  0.8316666666666667 	Loss:  0.6504620909690857
Epoch:  13


Running Accuracy: 0.8233823937795057,  	 Batch Loss:  0.7177773118019104: 100%|██████████| 113/113 [00:07<00:00, 15.91it/s]


Accuracy:  0.8375 	Loss:  0.7177773118019104
Epoch:  14


Running Accuracy: 0.8136628714246043,  	 Batch Loss:  0.6734763979911804: 100%|██████████| 113/113 [00:07<00:00, 15.73it/s]


Accuracy:  0.8341666666666666 	Loss:  0.6734763979911804
Epoch:  15


Running Accuracy: 0.8239377950569287,  	 Batch Loss:  0.586066722869873: 100%|██████████| 113/113 [00:07<00:00, 15.47it/s] 


KeyboardInterrupt: 

# Issue Here?

In [18]:
# Forward pass function
def forward(input_data):
    # Use the top-level model to decide which specialized model to use
    with torch.no_grad():
        decision = model_top(input_data)
        # Get the index (decision) for each input in the batch
        decisions = torch.argmax(decision, dim=1)  # This returns a tensor of decisions for each input in the batch
        print(decisions)

    # Initialize a batch of target tensors
    batch_size = input_data.size(0)
    target_tensor = torch.zeros(batch_size, 12).to(device)  # Initialize the target tensor for the entire batch

    for i in range(batch_size):
        # Route the input to the correct specialized model based on the decision for each element
        if decisions[i] == 0:
            output_tensor = model_1_4(input_data[i].unsqueeze(0))  # Process the ith element
            # target_size = min(output_tensor.size(1), 12)  # Ensure it doesn't exceed target tensor size
            target_tensor[i, :4] = output_tensor.squeeze(0)[:4]  # Overlay on the target tensor
        
        elif decisions[i] == 1:
            output_tensor = model_5_8(input_data[i].unsqueeze(0))  # Process the ith element
            # target_size = min(output_tensor.size(1), 7)  # Ensure it fits in the available space (12-5=7)
            target_tensor[i, 4:4+4] = output_tensor.squeeze(0)[:4]  # Overlay on the target tensor
            
        elif decisions[i] == 2:
            output_tensor = model_9_12(input_data[i].unsqueeze(0))  # Process the ith element
            # target_size = min(output_tensor.size(1), 3)  # Ensure it fits in the available space (12-9=3)
            target_tensor[i, 8:8+4] = output_tensor.squeeze(0)[:4]  # Overlay on the target tensor

    return target_tensor



In [19]:

root_dir = r'Data\Final Testing Dataset\Final Testing Dataset'
dataset = datasets.ImageFolder(root=root_dir, transform=transform)
val_size = int(0.2 * len(dataset))
test_size = int(0.2 * len(dataset))
train_size = len(dataset) - val_size - test_size

# Split the dataset into training, validation, and test sets
train_dataset, val_dataset, test_dataset = torch.utils.data.random_split(dataset, [train_size, val_size, test_size])
train_dataloader = torch.utils.data.DataLoader(train_dataset, batch_size=32, shuffle=True, num_workers=4)
val_dataloader = torch.utils.data.DataLoader(val_dataset, batch_size=32, shuffle=False, num_workers=4)
test_dataloader = torch.utils.data.DataLoader(test_dataset, batch_size=32, shuffle=False, num_workers=4)

In [20]:
correct = 0
total = 0
total_loss=0
with torch.no_grad():
    for (images, labels) in test_dataloader:
        images, labels = images.to(device), labels.to(device)
        outputs = forward(images)
        val_loss = nn.CrossEntropyLoss()(outputs, labels)
        total_loss += loss.item()
        _, predicted = torch.max(outputs.data, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()
    print("Accuracy: ", correct/total, "\tLoss: ", total_loss/len(test_dataloader))

tensor([1, 0, 0, 1, 0, 0, 1, 2, 0, 1, 1, 0, 1, 0, 0, 2, 1, 2, 0, 2, 0, 2, 1, 2,
        0, 2, 1, 0, 0, 0, 0, 2], device='cuda:0')
tensor([2, 0, 1, 0, 2, 0, 2, 1, 1, 2, 0, 0, 2, 2, 1, 0, 2, 1, 2, 0, 2, 0, 2, 0,
        1, 1, 1, 2, 1, 0, 1, 2], device='cuda:0')
tensor([2, 1, 1, 2, 1, 2, 2, 2, 2, 1, 2, 2, 1, 2, 0, 0, 2, 2, 0, 0, 2, 2, 1, 1,
        1, 0, 1, 2, 0, 2, 0, 0], device='cuda:0')
tensor([1, 0, 2, 2, 0, 0, 0, 0, 1, 2, 0, 2, 0, 2, 1, 1, 2, 1, 1, 1, 0, 0, 1, 2,
        1, 1, 0, 1, 2, 1, 2, 1], device='cuda:0')
tensor([2, 1, 1, 1, 1, 0, 1, 1, 2, 0, 1, 2, 1, 2, 2, 0, 1, 2, 2, 1, 0, 1, 2, 2,
        1, 1, 0, 2, 0, 1, 2, 2], device='cuda:0')
tensor([2, 2, 2, 0, 0, 2, 0, 2, 1, 2, 1, 1, 1, 2, 1, 0, 0, 2, 1, 2, 1, 1, 2, 1,
        2, 0, 2, 0, 2, 1, 1, 2], device='cuda:0')
tensor([1, 2, 1, 2, 1, 2, 2, 0, 1, 0, 0, 2, 2, 0, 1, 0, 1, 2, 2, 0, 0, 0, 0, 1,
        2, 0, 2, 2, 1, 2, 0, 1], device='cuda:0')
tensor([1, 1, 0, 1, 2, 1, 0, 1, 0, 2, 2, 2, 0, 1, 2, 1, 1, 2, 2, 1, 1, 1, 1, 2,
        0,

In [21]:
predicted

tensor([6, 0, 3, 0, 4, 3, 7, 5, 0, 7, 5, 5, 0, 2, 7, 4], device='cuda:0')

# Total

In [26]:
model_1_4

CNN(
  (model): Sequential(
    (0): Conv2d(3, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): ReLU()
    (2): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (3): Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (4): ReLU()
    (5): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (6): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (7): ReLU()
    (8): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (9): Flatten(start_dim=1, end_dim=-1)
    (10): Linear(in_features=3200, out_features=512, bias=True)
    (11): Linear(in_features=512, out_features=4, bias=True)
  )
)

In [12]:
class name_model(nn.Module):
    def __init__(self, model, position):
        super().__init__()
        self.model = model
        self.position = position
    def forward(self, x):
        model_output = self.model(x)
        out_tensor = torch.zeros(x.size(0), 12)
        start_idx = self.position*4
        end_idx = start_idx+4
        out_tensor[:, start_idx:end_idx] = model_output
        return out_tensor

In [18]:
class MixtureOfExperts(nn.Module):
    def __init__(self, general_model, specialists):
        super(MixtureOfExperts, self).__init__()
        self.general_model = general_model
        self.specialists = [name_model(model, index) for index, model in enumerate(specialists)]

    def forward(self, x):
        general_output = self.general_model(x) 
        specialist_indices = torch.argmax(general_output, dim=1)
        outputs = []
        
        # Process each input with the corresponding specialist
        for i, idx in enumerate(specialist_indices):
            selected_specialist = self.specialists[idx]  # Select the appropriate specialist
            outputs.append(selected_specialist(x[i].unsqueeze(0)))  # Apply specialist to the input
        
        # Concatenate the outputs to match batch dimension
        return torch.cat(outputs, dim=0)
    


# model_top

moe = MixtureOfExperts(model_top, [model_1_4, model_5_8, model_9_12]).to(device)


random = torch.randn(15, 3, 40, 40).to(device)

moe(random).shape


torch.Size([15, 12])

In [15]:

root_dir = r'Data\Final Testing Dataset\Final Testing Dataset'
dataset = datasets.ImageFolder(root=root_dir, transform=transform)
val_size = int(0.2 * len(dataset))
test_size = int(0.2 * len(dataset))
train_size = len(dataset) - val_size - test_size

# Split the dataset into training, validation, and test sets
train_dataset, val_dataset, test_dataset = torch.utils.data.random_split(dataset, [train_size, val_size, test_size])
train_dataloader = torch.utils.data.DataLoader(train_dataset, batch_size=32, shuffle=True, num_workers=4)
val_dataloader = torch.utils.data.DataLoader(val_dataset, batch_size=32, shuffle=False, num_workers=4)
test_dataloader = torch.utils.data.DataLoader(test_dataset, batch_size=32, shuffle=False, num_workers=4)

In [22]:
correct = 0
total = 0
total_loss=0
with torch.no_grad():
    for (images, labels) in test_dataloader:
        images, labels = images.to(device), labels.to(device)
        outputs = moe(images).to(device)
        # print(outputs)
        val_loss = nn.CrossEntropyLoss()(outputs, labels)
        total_loss += loss.item()
        _, predicted = torch.max(outputs.data, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()
    print("Accuracy: ", correct/total, "\tLoss: ", total_loss/len(test_dataloader))

  return F.softmax(self.model(x))


tensor([[0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
         0.0000e+00, 0.0000e+00, 4.8999e-01, 2.6391e-01, 1.4859e-01, 9.7509e-02],
        [3.6199e-05, 7.7985e-02, 1.7881e-02, 9.0410e-01, 0.0000e+00, 0.0000e+00,
         0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00],
        [2.1325e-05, 1.1728e-01, 7.5271e-02, 8.0742e-01, 0.0000e+00, 0.0000e+00,
         0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00],
        [1.8355e-05, 9.9357e-01, 4.3522e-04, 5.9761e-03, 0.0000e+00, 0.0000e+00,
         0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00],
        [0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
         0.0000e+00, 0.0000e+00, 5.3458e-04, 4.8837e-01, 5.1077e-01, 3.2670e-04],
        [0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 9.9781e-01, 2.0871e-03,
         8.1053e-05, 2.0058e-05, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00],
        [0.0000e+00, 0