In [1]:
import torch
import torch.nn as nn
import torchvision
from torchvision import datasets
from torchvision import transforms

import numpy as np
from tqdm import tqdm
import math

In [2]:
transform = transforms.Compose([
                    transforms.ToTensor(), 
                    # transforms.Normalize((0.5,), (0.5,)),
            ])

In [3]:
def CreateDataLoaders(Option, p_BatchSize):
    if Option == 1:
        train_dataset = datasets.MNIST("./", train=True, transform = transform, download=True)
        train_subset, val_subset = torch.utils.data.random_split(train_dataset, [int(len(train_dataset) * 0.85), int(len(train_dataset) * 0.15)], generator=torch.Generator().manual_seed(1))
        test_dataset = datasets.MNIST("./", train=False, transform = transform, download=True)
    elif Option == 2:
        train_dataset = datasets.FashionMNIST("./", train=True, transform = transform, download=True)
        train_subset, val_subset = torch.utils.data.random_split(train_dataset, [int(len(train_dataset) * 0.85), int(len(train_dataset) * 0.15)], generator=torch.Generator().manual_seed(1))
        test_dataset = datasets.FashionMNIST("./", train=False, transform = transform, download=True)
    
    elif Option == 3:
        train_dataset = datasets.STL10("./", split="train", transform = transform, download=True)
        train_subset, val_subset = torch.utils.data.random_split(train_dataset, [int(len(train_dataset) * 0.85), int(len(train_dataset) * 0.15)], generator=torch.Generator().manual_seed(1))
        test_dataset = datasets.STL10("./", split="test", transform = transform, download=True)
    
    elif Option == 4:
        train_dataset = datasets.SVHN("./", split="train", transform = transform, download=True)
        print(len(train_dataset))
        train_subset, val_subset = torch.utils.data.random_split(train_dataset, [int(len(train_dataset) * 0.85), math.ceil(len(train_dataset) * 0.15)], generator=torch.Generator().manual_seed(1))
        test_dataset = datasets.SVHN("./", split="test", transform = transform, download=True)
    
    elif Option == 5:
        train_dataset = datasets.CIFAR10("./", train=True, transform = transform, download=True)
        train_subset, val_subset = torch.utils.data.random_split(train_dataset, [int(len(train_dataset) * 0.85), int(len(train_dataset) * 0.15)], generator=torch.Generator().manual_seed(1))
        test_dataset = datasets.CIFAR10("./", train=False, transform = transform, download=True)

    elif Option == 6:
        train_dataset = datasets.CIFAR100("./", train=True, transform = transform, download=True)
        train_subset, val_subset = torch.utils.data.random_split(train_dataset, [int(len(train_dataset) * 0.85), int(len(train_dataset) * 0.15)], generator=torch.Generator().manual_seed(1))
        test_dataset = datasets.CIFAR100("./", train=False, transform = transform, download=True)

    
    Train_DataLoader = torch.utils.data.DataLoader(train_subset, batch_size = p_BatchSize, shuffle = True)
    Val_DataLoader = torch.utils.data.DataLoader(val_subset, batch_size = p_BatchSize, shuffle = True)
    Test_DataLoader = torch.utils.data.DataLoader(test_dataset, batch_size = p_BatchSize, shuffle = True)

    return Train_DataLoader, Val_DataLoader, Test_DataLoader

In [4]:
class MLP_IOCN(nn.Module):
    def __init__(self, InputDim, OutputDim):
        super(MLP_IOCN, self).__init__()
        self.InputDim = InputDim
        self.OutputDim = OutputDim
        self.Linear1 = nn.Linear(InputDim, 800)
        self.Linear2 = nn.Linear(800, 800)
        self.Linear3 = nn.Linear(800, self.OutputDim)
        # self.Linear4 = nn.Linear(800, self.OutputDim)
        self.ActFunc = nn.functional.elu
        self.batch = nn.BatchNorm1d(800)
        # self.SftMax = nn.Softmax

    def forward(self, x):
        x = self.ActFunc(self.batch(self.Linear1(x)))
        x = self.ActFunc(self.batch(self.Linear2(x)))
        # x = self.ActFunc(self.batch(self.Linear3(x)))

        # output = self.SftMax(self.Linear4(x))
        
        return self.Linear3(x)
    
    # def InitWeights(self):
    #     torch.nn.init.uniform_(self.Linear1.weight,-0.5, 0.5)
    #     torch.nn.init.uniform_(self.Linear2.weight,-0.5, 0.5)

In [5]:
model = MLP_IOCN(100, 10)
for name, param in model.named_parameters():
    print("name=", name)
    # print("param=", param)

name= Linear1.weight
name= Linear1.bias
name= Linear2.weight
name= Linear2.bias
name= Linear3.weight
name= Linear3.bias
name= batch.weight
name= batch.bias


In [6]:
def TrainModel(p_model, loss_criteria, Optimizer, device,  p_TrainDL, Gamma = 5):
    TrainCorr = 0
    TotNumOfSamples = 0

    for images, labels in tqdm(p_TrainDL):
        
        for name, param in p_model.named_parameters():
            # if "Linear1" in name or "bias" in name:
            if "Linear1" in name or "bias" in name:
                continue
            
            tmpParam = param.data
            NewTmpParam = torch.exp(tmpParam - Gamma)
            tmpParam = torch.where(tmpParam<0, NewTmpParam, tmpParam)
            param.data = tmpParam

            # temp_weight = param.data
            # # temp_weight[temp_weight < 0] = torch.clamp(temp_weight[temp_weight < 0], min=0)
            # temp_weight[temp_weight < 0] = torch.exp(temp_weight[temp_weight < 0])
            # param.data = temp_weight
            
        Optimizer.zero_grad()

        images = images.reshape(images.shape[0], -1).to(device)
        labels = labels.to(device)

        pred = p_model(images)
        # print(type(pred))
        # print(pred.shape)
        # print(pred)
        # pred=pred.to(torch.float)
        # labels=labels.to(torch.float)
        predClass = torch.max(pred.data, 1)[1]
        # predClass = np.argmax(pred, axis=-1)
        # predClass = pred.max(1)
        # print("Pred", predClass)
        # print(type(labels))
        # print("labels", labels)

        # predClass=predClass.to(torch.float)
        loss = loss_criteria(pred, labels)
        TrainCorr += (predClass==labels).sum()
        TotNumOfSamples += len(labels)

        loss.backward()
        Optimizer.step()
    
    return TrainCorr/TotNumOfSamples

def EvaluateModel(p_model, p_loader, device):
    Correct = 0
    TotalNumOfSamples = 0

    for images, labels in tqdm(p_loader):
        images = images.reshape(images.shape[0], -1).to(device)
        labels = labels.to(device)

        pred = p_model(images)
        predClass = torch.max(pred.data, 1)[1]
        Correct += (predClass==labels).sum()
        TotalNumOfSamples += len(labels)
        # loss = loss_criteria(predClass, labels)
        # Accuracy += loss.cpu().data.item()

    return Correct/TotalNumOfSamples

In [7]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
EPOCHS = 100

In [None]:
# -------------------------------------------- MNIST Dataset -----------------------------------------------------------
Option = 1
Train_DataLoader, Val_DataLoader, Test_DataLoader = CreateDataLoaders(Option, 64)
ModelName = "Model_MLP_IOCN_MNIST.pt"

model = MLP_IOCN(784, 10).to(device)
loss_criteria = nn.CrossEntropyLoss()
AdamOpt = torch.optim.Adam(model.parameters(), lr=0.0001)

ValAccuracy = 0

saved = False
if saved == True:
    saved_model = torch.load(ModelName, map_location=torch.device('cpu')).to(device)
else:
    model.train()
    Train_Accuracy = 0

    for e in range(EPOCHS):
        Train_Accuracy = TrainModel(model, loss_criteria, AdamOpt, device, Train_DataLoader)
        val_acc = EvaluateModel(model, Val_DataLoader, device)

        print("EPOCH - ", e+1, ". Train Accuracy = ", Train_Accuracy.cpu().item(), ", Validation Accuracy = ", val_acc.cpu().item())
        
        if val_acc.cpu().item() > ValAccuracy:
            print("Model Re-Saved")
            ValAccuracy = val_acc.cpu().item()
            torch.save(model, ModelName)


    saved_model = torch.load(ModelName, map_location=torch.device('cpu')).to(device)

Train_Accuracy = EvaluateModel(saved_model, Train_DataLoader, device)
print("Train Accuracy = ", Train_Accuracy.cpu().item())
Test_Accuracy = EvaluateModel(saved_model, Test_DataLoader, device)
print("Test Accuracy = ", Test_Accuracy.cpu().item())

print("Generalization Gap = ", (Train_Accuracy.cpu().item() - Test_Accuracy.cpu().item()))

100%|██████████| 797/797 [00:08<00:00, 89.36it/s]
100%|██████████| 141/141 [00:01<00:00, 95.96it/s]


EPOCH -  1 . Train Accuracy =  0.8982353210449219 , Validation Accuracy =  0.9318888783454895
Model Re-Saved


100%|██████████| 797/797 [00:09<00:00, 85.82it/s]
100%|██████████| 141/141 [00:01<00:00, 109.57it/s]


EPOCH -  2 . Train Accuracy =  0.9434705972671509 , Validation Accuracy =  0.9486666917800903
Model Re-Saved


100%|██████████| 797/797 [00:08<00:00, 90.07it/s]
100%|██████████| 141/141 [00:01<00:00, 114.13it/s]


EPOCH -  3 . Train Accuracy =  0.9573529362678528 , Validation Accuracy =  0.9545555710792542
Model Re-Saved


100%|██████████| 797/797 [00:08<00:00, 91.84it/s]
100%|██████████| 141/141 [00:01<00:00, 83.11it/s]


EPOCH -  4 . Train Accuracy =  0.9648823738098145 , Validation Accuracy =  0.9594444632530212
Model Re-Saved


100%|██████████| 797/797 [00:08<00:00, 95.78it/s]
100%|██████████| 141/141 [00:01<00:00, 111.64it/s]


EPOCH -  5 . Train Accuracy =  0.9710980653762817 , Validation Accuracy =  0.9611111283302307
Model Re-Saved


100%|██████████| 797/797 [00:09<00:00, 87.88it/s]
100%|██████████| 141/141 [00:01<00:00, 113.38it/s]


EPOCH -  6 . Train Accuracy =  0.9742745161056519 , Validation Accuracy =  0.9633333683013916
Model Re-Saved


100%|██████████| 797/797 [00:08<00:00, 89.28it/s]
100%|██████████| 141/141 [00:01<00:00, 111.44it/s]


EPOCH -  7 . Train Accuracy =  0.9772353172302246 , Validation Accuracy =  0.9623333215713501


100%|██████████| 797/797 [00:09<00:00, 88.44it/s]
100%|██████████| 141/141 [00:01<00:00, 111.53it/s]


EPOCH -  8 . Train Accuracy =  0.980470597743988 , Validation Accuracy =  0.9676666855812073
Model Re-Saved


100%|██████████| 797/797 [00:08<00:00, 88.73it/s]
100%|██████████| 141/141 [00:01<00:00, 111.33it/s]


EPOCH -  9 . Train Accuracy =  0.9823921918869019 , Validation Accuracy =  0.9632222056388855


100%|██████████| 797/797 [00:08<00:00, 92.39it/s]
100%|██████████| 141/141 [00:01<00:00, 87.22it/s] 


EPOCH -  10 . Train Accuracy =  0.9834117889404297 , Validation Accuracy =  0.9671111106872559


100%|██████████| 797/797 [00:08<00:00, 94.56it/s]
100%|██████████| 141/141 [00:01<00:00, 104.93it/s]


EPOCH -  11 . Train Accuracy =  0.9846863150596619 , Validation Accuracy =  0.9688888788223267
Model Re-Saved


100%|██████████| 797/797 [00:08<00:00, 89.23it/s]
100%|██████████| 141/141 [00:01<00:00, 114.66it/s]


EPOCH -  12 . Train Accuracy =  0.9867058992385864 , Validation Accuracy =  0.9679999947547913


100%|██████████| 797/797 [00:09<00:00, 88.27it/s]
100%|██████████| 141/141 [00:01<00:00, 108.17it/s]


EPOCH -  13 . Train Accuracy =  0.9879019856452942 , Validation Accuracy =  0.9687777757644653


100%|██████████| 797/797 [00:08<00:00, 88.57it/s]
100%|██████████| 141/141 [00:01<00:00, 113.17it/s]


EPOCH -  14 . Train Accuracy =  0.9881568551063538 , Validation Accuracy =  0.9701111316680908
Model Re-Saved


100%|██████████| 797/797 [00:08<00:00, 89.51it/s]
100%|██████████| 141/141 [00:01<00:00, 110.19it/s]


EPOCH -  15 . Train Accuracy =  0.9885294437408447 , Validation Accuracy =  0.9682222604751587


100%|██████████| 797/797 [00:08<00:00, 93.51it/s]
100%|██████████| 141/141 [00:01<00:00, 86.56it/s] 


EPOCH -  16 . Train Accuracy =  0.9910784363746643 , Validation Accuracy =  0.968666672706604


100%|██████████| 797/797 [00:08<00:00, 96.81it/s]
100%|██████████| 141/141 [00:01<00:00, 105.69it/s]


EPOCH -  17 . Train Accuracy =  0.990882396697998 , Validation Accuracy =  0.9665555953979492


100%|██████████| 797/797 [00:08<00:00, 90.15it/s]
100%|██████████| 141/141 [00:01<00:00, 110.81it/s]


EPOCH -  18 . Train Accuracy =  0.9911372661590576 , Validation Accuracy =  0.9681110978126526


100%|██████████| 797/797 [00:08<00:00, 89.24it/s]
100%|██████████| 141/141 [00:01<00:00, 111.25it/s]


EPOCH -  19 . Train Accuracy =  0.9917647242546082 , Validation Accuracy =  0.9717777967453003
Model Re-Saved


100%|██████████| 797/797 [00:08<00:00, 89.10it/s]
100%|██████████| 141/141 [00:01<00:00, 114.33it/s]


EPOCH -  20 . Train Accuracy =  0.9913725852966309 , Validation Accuracy =  0.9718888998031616
Model Re-Saved


100%|██████████| 797/797 [00:09<00:00, 81.87it/s]
100%|██████████| 141/141 [00:01<00:00, 111.66it/s]


EPOCH -  21 . Train Accuracy =  0.9928627610206604 , Validation Accuracy =  0.972000002861023
Model Re-Saved


100%|██████████| 797/797 [00:08<00:00, 90.34it/s]
100%|██████████| 141/141 [00:01<00:00, 101.78it/s]


EPOCH -  22 . Train Accuracy =  0.9928235411643982 , Validation Accuracy =  0.971666693687439


100%|██████████| 797/797 [00:08<00:00, 95.70it/s] 
100%|██████████| 141/141 [00:01<00:00, 91.98it/s]


EPOCH -  23 . Train Accuracy =  0.9940000176429749 , Validation Accuracy =  0.9700000286102295


100%|██████████| 797/797 [00:08<00:00, 91.13it/s]
100%|██████████| 141/141 [00:01<00:00, 108.68it/s]


EPOCH -  24 . Train Accuracy =  0.9929019808769226 , Validation Accuracy =  0.971666693687439


100%|██████████| 797/797 [00:09<00:00, 87.01it/s]
100%|██████████| 141/141 [00:01<00:00, 113.40it/s]


EPOCH -  25 . Train Accuracy =  0.9945686459541321 , Validation Accuracy =  0.9707778096199036


100%|██████████| 797/797 [00:09<00:00, 87.80it/s]
100%|██████████| 141/141 [00:01<00:00, 110.46it/s]


EPOCH -  26 . Train Accuracy =  0.9939804077148438 , Validation Accuracy =  0.9711111187934875


100%|██████████| 797/797 [00:09<00:00, 88.06it/s]
100%|██████████| 141/141 [00:01<00:00, 111.63it/s]


EPOCH -  27 . Train Accuracy =  0.9940980672836304 , Validation Accuracy =  0.9668889045715332


100%|██████████| 797/797 [00:09<00:00, 88.01it/s]
100%|██████████| 141/141 [00:01<00:00, 111.17it/s]


EPOCH -  28 . Train Accuracy =  0.9940392374992371 , Validation Accuracy =  0.9714444279670715


100%|██████████| 797/797 [00:08<00:00, 95.47it/s]
100%|██████████| 141/141 [00:01<00:00, 80.99it/s]


EPOCH -  29 . Train Accuracy =  0.9946666955947876 , Validation Accuracy =  0.9717777967453003


100%|██████████| 797/797 [00:08<00:00, 93.42it/s]
100%|██████████| 141/141 [00:01<00:00, 110.21it/s]


EPOCH -  30 . Train Accuracy =  0.995784342288971 , Validation Accuracy =  0.9767777919769287
Model Re-Saved


100%|██████████| 797/797 [00:08<00:00, 89.44it/s]
100%|██████████| 141/141 [00:01<00:00, 112.67it/s]


EPOCH -  31 . Train Accuracy =  0.9954705834388733 , Validation Accuracy =  0.9717777967453003


100%|██████████| 797/797 [00:09<00:00, 87.86it/s]
100%|██████████| 141/141 [00:01<00:00, 112.44it/s]


EPOCH -  32 . Train Accuracy =  0.9947647452354431 , Validation Accuracy =  0.9711111187934875


100%|██████████| 797/797 [00:08<00:00, 90.25it/s]
100%|██████████| 141/141 [00:01<00:00, 112.44it/s]


EPOCH -  33 . Train Accuracy =  0.9956470727920532 , Validation Accuracy =  0.9734444618225098


100%|██████████| 797/797 [00:08<00:00, 89.54it/s]
100%|██████████| 141/141 [00:01<00:00, 113.87it/s]


EPOCH -  34 . Train Accuracy =  0.9954705834388733 , Validation Accuracy =  0.9721111059188843


100%|██████████| 797/797 [00:08<00:00, 95.26it/s]
100%|██████████| 141/141 [00:01<00:00, 82.45it/s]


EPOCH -  35 . Train Accuracy =  0.9949019551277161 , Validation Accuracy =  0.9721111059188843


100%|██████████| 797/797 [00:08<00:00, 93.26it/s]
100%|██████████| 141/141 [00:01<00:00, 112.78it/s]


EPOCH -  36 . Train Accuracy =  0.9964314103126526 , Validation Accuracy =  0.9740000367164612


100%|██████████| 797/797 [00:09<00:00, 87.96it/s]
100%|██████████| 141/141 [00:01<00:00, 114.01it/s]


EPOCH -  37 . Train Accuracy =  0.996843159198761 , Validation Accuracy =  0.9723333716392517


100%|██████████| 797/797 [00:09<00:00, 87.65it/s]
100%|██████████| 141/141 [00:01<00:00, 107.82it/s]


EPOCH -  38 . Train Accuracy =  0.9966862797737122 , Validation Accuracy =  0.972777783870697


100%|██████████| 797/797 [00:09<00:00, 88.14it/s]
100%|██████████| 141/141 [00:01<00:00, 113.21it/s]


EPOCH -  39 . Train Accuracy =  0.9954705834388733 , Validation Accuracy =  0.9703333377838135


100%|██████████| 797/797 [00:09<00:00, 83.71it/s]
100%|██████████| 141/141 [00:01<00:00, 83.88it/s] 


EPOCH -  40 . Train Accuracy =  0.9957059025764465 , Validation Accuracy =  0.976555585861206


100%|██████████| 797/797 [00:08<00:00, 94.02it/s]
100%|██████████| 141/141 [00:01<00:00, 84.01it/s]


EPOCH -  41 . Train Accuracy =  0.9970000386238098 , Validation Accuracy =  0.9756667017936707


100%|██████████| 797/797 [00:08<00:00, 91.89it/s]
100%|██████████| 141/141 [00:01<00:00, 108.71it/s]


EPOCH -  42 . Train Accuracy =  0.9961764812469482 , Validation Accuracy =  0.9725555777549744


100%|██████████| 797/797 [00:08<00:00, 89.23it/s]
100%|██████████| 141/141 [00:01<00:00, 111.00it/s]


EPOCH -  43 . Train Accuracy =  0.9962941408157349 , Validation Accuracy =  0.9734444618225098


100%|██████████| 797/797 [00:08<00:00, 89.62it/s]
100%|██████████| 141/141 [00:01<00:00, 111.33it/s]


EPOCH -  44 . Train Accuracy =  0.9968039393424988 , Validation Accuracy =  0.9722222089767456


100%|██████████| 797/797 [00:09<00:00, 88.04it/s]
100%|██████████| 141/141 [00:01<00:00, 113.06it/s]


EPOCH -  45 . Train Accuracy =  0.9965490698814392 , Validation Accuracy =  0.9728888869285583


100%|██████████| 797/797 [00:08<00:00, 89.06it/s]
100%|██████████| 141/141 [00:01<00:00, 110.34it/s]


EPOCH -  46 . Train Accuracy =  0.996843159198761 , Validation Accuracy =  0.9738888740539551


100%|██████████| 797/797 [00:08<00:00, 95.62it/s]
100%|██████████| 141/141 [00:01<00:00, 79.43it/s]


EPOCH -  47 . Train Accuracy =  0.9970000386238098 , Validation Accuracy =  0.9769999980926514
Model Re-Saved


100%|██████████| 797/797 [00:08<00:00, 95.31it/s]
100%|██████████| 141/141 [00:01<00:00, 111.72it/s]


EPOCH -  48 . Train Accuracy =  0.9968823790550232 , Validation Accuracy =  0.968666672706604


100%|██████████| 797/797 [00:08<00:00, 89.86it/s]
100%|██████████| 141/141 [00:01<00:00, 112.84it/s]


EPOCH -  49 . Train Accuracy =  0.9972941279411316 , Validation Accuracy =  0.9735555648803711


100%|██████████| 797/797 [00:08<00:00, 89.25it/s]
100%|██████████| 141/141 [00:01<00:00, 113.61it/s]


EPOCH -  50 . Train Accuracy =  0.9972352981567383 , Validation Accuracy =  0.972444474697113


100%|██████████| 797/797 [00:08<00:00, 89.21it/s]
100%|██████████| 141/141 [00:01<00:00, 112.98it/s]


EPOCH -  51 . Train Accuracy =  0.9969215989112854 , Validation Accuracy =  0.972000002861023


100%|██████████| 797/797 [00:08<00:00, 92.38it/s]
100%|██████████| 141/141 [00:01<00:00, 92.28it/s] 


EPOCH -  52 . Train Accuracy =  0.9974902272224426 , Validation Accuracy =  0.9715555906295776


100%|██████████| 797/797 [00:08<00:00, 96.07it/s]
100%|██████████| 141/141 [00:01<00:00, 90.29it/s]


EPOCH -  53 . Train Accuracy =  0.997039258480072 , Validation Accuracy =  0.9755555391311646


100%|██████████| 797/797 [00:08<00:00, 91.47it/s]
100%|██████████| 141/141 [00:01<00:00, 110.77it/s]


EPOCH -  54 . Train Accuracy =  0.9985294342041016 , Validation Accuracy =  0.9740000367164612


100%|██████████| 797/797 [00:08<00:00, 89.08it/s]
100%|██████████| 141/141 [00:01<00:00, 111.61it/s]


EPOCH -  55 . Train Accuracy =  0.9969608187675476 , Validation Accuracy =  0.9728888869285583


100%|██████████| 797/797 [00:08<00:00, 88.97it/s]
100%|██████████| 141/141 [00:01<00:00, 111.52it/s]


EPOCH -  56 . Train Accuracy =  0.9976274967193604 , Validation Accuracy =  0.9738888740539551


100%|██████████| 797/797 [00:08<00:00, 89.59it/s]
100%|██████████| 141/141 [00:01<00:00, 112.58it/s]


EPOCH -  57 . Train Accuracy =  0.9974902272224426 , Validation Accuracy =  0.9747778177261353


100%|██████████| 797/797 [00:08<00:00, 92.02it/s]
100%|██████████| 141/141 [00:01<00:00, 92.26it/s] 


EPOCH -  58 . Train Accuracy =  0.9973921775817871 , Validation Accuracy =  0.97688889503479


100%|██████████| 797/797 [00:08<00:00, 97.46it/s]
100%|██████████| 141/141 [00:01<00:00, 103.18it/s]


EPOCH -  59 . Train Accuracy =  0.9972745180130005 , Validation Accuracy =  0.976111114025116


100%|██████████| 797/797 [00:09<00:00, 83.98it/s]
100%|██████████| 141/141 [00:01<00:00, 112.90it/s]


EPOCH -  60 . Train Accuracy =  0.9976274967193604 , Validation Accuracy =  0.9740000367164612


100%|██████████| 797/797 [00:08<00:00, 88.89it/s]
100%|██████████| 141/141 [00:01<00:00, 108.46it/s]


EPOCH -  61 . Train Accuracy =  0.9982745051383972 , Validation Accuracy =  0.9742222428321838


100%|██████████| 797/797 [00:08<00:00, 90.37it/s]
100%|██████████| 141/141 [00:01<00:00, 111.61it/s]


EPOCH -  62 . Train Accuracy =  0.9974117875099182 , Validation Accuracy =  0.9723333716392517


100%|██████████| 797/797 [00:08<00:00, 89.25it/s]
100%|██████████| 141/141 [00:01<00:00, 111.23it/s]


EPOCH -  63 . Train Accuracy =  0.9977843165397644 , Validation Accuracy =  0.976555585861206


100%|██████████| 797/797 [00:08<00:00, 92.94it/s]
100%|██████████| 141/141 [00:01<00:00, 92.19it/s] 


EPOCH -  64 . Train Accuracy =  0.9978039264678955 , Validation Accuracy =  0.9754444360733032


100%|██████████| 797/797 [00:08<00:00, 97.34it/s]
100%|██████████| 141/141 [00:01<00:00, 101.94it/s]


EPOCH -  65 . Train Accuracy =  0.99770587682724 , Validation Accuracy =  0.9734444618225098


100%|██████████| 797/797 [00:08<00:00, 91.40it/s]
100%|██████████| 141/141 [00:01<00:00, 113.82it/s]


EPOCH -  66 . Train Accuracy =  0.997901976108551 , Validation Accuracy =  0.9774444699287415
Model Re-Saved


100%|██████████| 797/797 [00:09<00:00, 87.54it/s]
100%|██████████| 141/141 [00:01<00:00, 106.62it/s]


EPOCH -  67 . Train Accuracy =  0.9974902272224426 , Validation Accuracy =  0.9753333330154419


100%|██████████| 797/797 [00:09<00:00, 88.14it/s]
100%|██████████| 141/141 [00:01<00:00, 109.29it/s]


EPOCH -  68 . Train Accuracy =  0.9978039264678955 , Validation Accuracy =  0.9744444489479065


100%|██████████| 797/797 [00:08<00:00, 88.56it/s]
100%|██████████| 141/141 [00:01<00:00, 113.73it/s]


EPOCH -  69 . Train Accuracy =  0.9987255334854126 , Validation Accuracy =  0.9754444360733032


100%|██████████| 797/797 [00:08<00:00, 93.59it/s]
100%|██████████| 141/141 [00:01<00:00, 89.71it/s] 


EPOCH -  70 . Train Accuracy =  0.9980588555335999 , Validation Accuracy =  0.9713333249092102


100%|██████████| 797/797 [00:08<00:00, 96.90it/s]
100%|██████████| 141/141 [00:01<00:00, 111.51it/s]


EPOCH -  71 . Train Accuracy =  0.99770587682724 , Validation Accuracy =  0.9737777709960938


100%|██████████| 797/797 [00:08<00:00, 89.63it/s]
100%|██████████| 141/141 [00:01<00:00, 113.09it/s]


EPOCH -  72 . Train Accuracy =  0.9979608058929443 , Validation Accuracy =  0.9754444360733032


100%|██████████| 797/797 [00:08<00:00, 89.67it/s]
100%|██████████| 141/141 [00:01<00:00, 111.36it/s]


EPOCH -  73 . Train Accuracy =  0.9977843165397644 , Validation Accuracy =  0.9758889079093933


100%|██████████| 797/797 [00:08<00:00, 89.31it/s]
100%|██████████| 141/141 [00:01<00:00, 111.82it/s]


EPOCH -  74 . Train Accuracy =  0.9983921647071838 , Validation Accuracy =  0.9746666550636292


100%|██████████| 797/797 [00:08<00:00, 89.58it/s]
100%|██████████| 141/141 [00:01<00:00, 112.34it/s]


EPOCH -  75 . Train Accuracy =  0.9977254867553711 , Validation Accuracy =  0.9747778177261353


100%|██████████| 797/797 [00:08<00:00, 95.71it/s]
100%|██████████| 141/141 [00:01<00:00, 77.84it/s]


EPOCH -  76 . Train Accuracy =  0.9984706044197083 , Validation Accuracy =  0.9760000109672546


100%|██████████| 797/797 [00:08<00:00, 95.92it/s]
100%|██████████| 141/141 [00:01<00:00, 110.73it/s]


EPOCH -  77 . Train Accuracy =  0.9980588555335999 , Validation Accuracy =  0.9738888740539551


100%|██████████| 797/797 [00:09<00:00, 88.48it/s]
100%|██████████| 141/141 [00:01<00:00, 111.79it/s]


EPOCH -  78 . Train Accuracy =  0.9982548952102661 , Validation Accuracy =  0.973111093044281


100%|██████████| 797/797 [00:08<00:00, 90.47it/s]
100%|██████████| 141/141 [00:01<00:00, 111.98it/s]


EPOCH -  79 . Train Accuracy =  0.9984509944915771 , Validation Accuracy =  0.9746666550636292


100%|██████████| 797/797 [00:09<00:00, 83.30it/s]
100%|██████████| 141/141 [00:01<00:00, 112.44it/s]


EPOCH -  80 . Train Accuracy =  0.998431384563446 , Validation Accuracy =  0.9740000367164612


100%|██████████| 797/797 [00:08<00:00, 90.43it/s]
100%|██████████| 141/141 [00:01<00:00, 113.18it/s]


EPOCH -  81 . Train Accuracy =  0.9979215860366821 , Validation Accuracy =  0.9734444618225098


100%|██████████| 797/797 [00:08<00:00, 94.11it/s]
100%|██████████| 141/141 [00:01<00:00, 84.76it/s]


EPOCH -  82 . Train Accuracy =  0.9985294342041016 , Validation Accuracy =  0.9764444828033447


100%|██████████| 797/797 [00:08<00:00, 97.91it/s]
100%|██████████| 141/141 [00:01<00:00, 102.85it/s]


EPOCH -  83 . Train Accuracy =  0.9982745051383972 , Validation Accuracy =  0.9753333330154419


100%|██████████| 797/797 [00:08<00:00, 88.75it/s]
100%|██████████| 141/141 [00:01<00:00, 113.53it/s]


EPOCH -  84 . Train Accuracy =  0.998431384563446 , Validation Accuracy =  0.9752222299575806


100%|██████████| 797/797 [00:09<00:00, 88.36it/s]
100%|██████████| 141/141 [00:01<00:00, 108.90it/s]


EPOCH -  85 . Train Accuracy =  0.998431384563446 , Validation Accuracy =  0.9750000238418579


100%|██████████| 797/797 [00:08<00:00, 88.61it/s]
100%|██████████| 141/141 [00:01<00:00, 112.63it/s]


EPOCH -  86 . Train Accuracy =  0.9982548952102661 , Validation Accuracy =  0.9750000238418579


100%|██████████| 797/797 [00:08<00:00, 90.79it/s]
100%|██████████| 141/141 [00:01<00:00, 111.89it/s]


EPOCH -  87 . Train Accuracy =  0.9985490441322327 , Validation Accuracy =  0.9732222557067871


100%|██████████| 797/797 [00:08<00:00, 96.05it/s]
100%|██████████| 141/141 [00:01<00:00, 79.22it/s]


EPOCH -  88 . Train Accuracy =  0.9982157349586487 , Validation Accuracy =  0.9753333330154419


100%|██████████| 797/797 [00:08<00:00, 96.34it/s]
100%|██████████| 141/141 [00:01<00:00, 113.83it/s]


EPOCH -  89 . Train Accuracy =  0.9989804029464722 , Validation Accuracy =  0.9744444489479065


100%|██████████| 797/797 [00:08<00:00, 89.33it/s]
100%|██████████| 141/141 [00:01<00:00, 114.23it/s]


EPOCH -  90 . Train Accuracy =  0.9987059235572815 , Validation Accuracy =  0.9746666550636292


100%|██████████| 797/797 [00:08<00:00, 89.78it/s]
100%|██████████| 141/141 [00:01<00:00, 107.60it/s]


EPOCH -  91 . Train Accuracy =  0.998607873916626 , Validation Accuracy =  0.9726666808128357


100%|██████████| 797/797 [00:08<00:00, 88.99it/s]
100%|██████████| 141/141 [00:01<00:00, 110.94it/s]


EPOCH -  92 . Train Accuracy =  0.9978627562522888 , Validation Accuracy =  0.9762222170829773


100%|██████████| 797/797 [00:08<00:00, 89.02it/s]
100%|██████████| 141/141 [00:01<00:00, 111.80it/s]


EPOCH -  93 . Train Accuracy =  0.9989019632339478 , Validation Accuracy =  0.9736666679382324


100%|██████████| 797/797 [00:08<00:00, 96.03it/s]
100%|██████████| 141/141 [00:01<00:00, 79.44it/s]


EPOCH -  94 . Train Accuracy =  0.998431384563446 , Validation Accuracy =  0.9718888998031616


100%|██████████| 797/797 [00:08<00:00, 95.60it/s]
100%|██████████| 141/141 [00:01<00:00, 111.51it/s]


EPOCH -  95 . Train Accuracy =  0.9982548952102661 , Validation Accuracy =  0.972777783870697


100%|██████████| 797/797 [00:08<00:00, 89.52it/s]
100%|██████████| 141/141 [00:01<00:00, 108.61it/s]


EPOCH -  96 . Train Accuracy =  0.9982941150665283 , Validation Accuracy =  0.9782222509384155
Model Re-Saved


100%|██████████| 797/797 [00:08<00:00, 89.72it/s]
100%|██████████| 141/141 [00:01<00:00, 109.75it/s]


EPOCH -  97 . Train Accuracy =  0.9990000128746033 , Validation Accuracy =  0.9774444699287415


100%|██████████| 797/797 [00:08<00:00, 89.21it/s]
100%|██████████| 141/141 [00:01<00:00, 112.33it/s]


EPOCH -  98 . Train Accuracy =  0.9985882639884949 , Validation Accuracy =  0.9734444618225098


100%|██████████| 797/797 [00:08<00:00, 88.68it/s]
100%|██████████| 141/141 [00:01<00:00, 101.70it/s]


EPOCH -  99 . Train Accuracy =  0.9978627562522888 , Validation Accuracy =  0.9763333201408386


100%|██████████| 797/797 [00:08<00:00, 90.04it/s]
100%|██████████| 141/141 [00:01<00:00, 77.88it/s]


EPOCH -  100 . Train Accuracy =  0.9993921518325806 , Validation Accuracy =  0.9781111478805542


100%|██████████| 797/797 [00:07<00:00, 110.76it/s]


Train Accuracy =  0.999333381652832


100%|██████████| 157/157 [00:01<00:00, 110.57it/s]

Test Accuracy =  0.9781000018119812
Generalization Gap =  0.02123337984085083





In [None]:
saved_model = torch.load(ModelName, map_location=torch.device('cpu')).to(device)

Train_Accuracy = EvaluateModel(saved_model, Train_DataLoader, device)
print("Train Accuracy = ", Train_Accuracy.cpu().item())
Test_Accuracy = EvaluateModel(saved_model, Test_DataLoader, device)
print("Test Accuracy = ", Test_Accuracy.cpu().item())

print("Generalization Gap = ", (Train_Accuracy.cpu().item() - Test_Accuracy.cpu().item()))

100%|██████████| 797/797 [00:07<00:00, 101.47it/s]


Train Accuracy =  0.9889608025550842


100%|██████████| 157/157 [00:01<00:00, 105.81it/s]

Test Accuracy =  0.9723999500274658
Generalization Gap =  0.016560852527618408





In [None]:
# -------------------------------------------- FMNIST Dataset -----------------------------------------------------------
Option = 2
Train_DataLoader, Val_DataLoader, Test_DataLoader = CreateDataLoaders(Option, 64)
ModelName = "Model_MLP_IOCN_FMNIST.pt"

model = MLP_IOCN(784, 10).to(device)
loss_criteria = nn.CrossEntropyLoss()
AdamOpt = torch.optim.Adam(model.parameters(), lr=0.0001)

ValAccuracy = 0

saved = False
if saved == True:
    saved_model = torch.load(ModelName, map_location=torch.device('cpu')).to(device)
else:
    model.train()
    Train_Accuracy = 0

    for e in range(EPOCHS):
        Train_Accuracy = TrainModel(model, loss_criteria, AdamOpt, device, Train_DataLoader)
        val_acc = EvaluateModel(model, Val_DataLoader, device)

        print("EPOCH - ", e+1, ". Train Accuracy = ", Train_Accuracy.cpu().item(), ", Validation Accuracy = ", val_acc.cpu().item())
        
        if val_acc.cpu().item() > ValAccuracy:
            print("Model Re-Saved")
            ValAccuracy = val_acc.cpu().item()
            torch.save(model, ModelName)


    saved_model = torch.load(ModelName, map_location=torch.device('cpu')).to(device)

Train_Accuracy = EvaluateModel(saved_model, Train_DataLoader, device)
print("Train Accuracy = ", Train_Accuracy.cpu().item())
Test_Accuracy = EvaluateModel(saved_model, Test_DataLoader, device)
print("Test Accuracy = ", Test_Accuracy.cpu().item())

print("Generalization Gap = ", (Train_Accuracy.cpu().item() - Test_Accuracy.cpu().item()))

100%|██████████| 797/797 [00:12<00:00, 65.29it/s]
100%|██████████| 141/141 [00:01<00:00, 104.15it/s]


EPOCH -  1 . Train Accuracy =  0.8138039112091064 , Validation Accuracy =  0.8431110978126526
Model Re-Saved


100%|██████████| 797/797 [00:09<00:00, 85.36it/s]
100%|██████████| 141/141 [00:01<00:00, 108.30it/s]


EPOCH -  2 . Train Accuracy =  0.8562549352645874 , Validation Accuracy =  0.8551111221313477
Model Re-Saved


100%|██████████| 797/797 [00:09<00:00, 84.45it/s]
100%|██████████| 141/141 [00:01<00:00, 106.13it/s]


EPOCH -  3 . Train Accuracy =  0.866921603679657 , Validation Accuracy =  0.8611111044883728
Model Re-Saved


100%|██████████| 797/797 [00:09<00:00, 85.27it/s]
100%|██████████| 141/141 [00:01<00:00, 110.70it/s]


EPOCH -  4 . Train Accuracy =  0.8758627772331238 , Validation Accuracy =  0.8657777905464172
Model Re-Saved


100%|██████████| 797/797 [00:09<00:00, 85.02it/s]
100%|██████████| 141/141 [00:01<00:00, 108.01it/s]


EPOCH -  5 . Train Accuracy =  0.8799020051956177 , Validation Accuracy =  0.8683333396911621
Model Re-Saved


100%|██████████| 797/797 [00:09<00:00, 83.87it/s]
100%|██████████| 141/141 [00:01<00:00, 101.40it/s]


EPOCH -  6 . Train Accuracy =  0.8861372470855713 , Validation Accuracy =  0.8642222285270691


100%|██████████| 797/797 [00:10<00:00, 77.71it/s]
100%|██████████| 141/141 [00:01<00:00, 106.26it/s]


EPOCH -  7 . Train Accuracy =  0.8909804224967957 , Validation Accuracy =  0.8619999885559082


100%|██████████| 797/797 [00:08<00:00, 88.60it/s]
100%|██████████| 141/141 [00:01<00:00, 78.88it/s]


EPOCH -  8 . Train Accuracy =  0.895451009273529 , Validation Accuracy =  0.8630000352859497


100%|██████████| 797/797 [00:08<00:00, 92.94it/s]
100%|██████████| 141/141 [00:01<00:00, 101.85it/s]


EPOCH -  9 . Train Accuracy =  0.8990588188171387 , Validation Accuracy =  0.8774444460868835
Model Re-Saved


100%|██████████| 797/797 [00:09<00:00, 87.33it/s]
100%|██████████| 141/141 [00:01<00:00, 108.96it/s]


EPOCH -  10 . Train Accuracy =  0.9016667008399963 , Validation Accuracy =  0.8795555830001831
Model Re-Saved


100%|██████████| 797/797 [00:09<00:00, 84.90it/s]
100%|██████████| 141/141 [00:01<00:00, 108.85it/s]


EPOCH -  11 . Train Accuracy =  0.9047058820724487 , Validation Accuracy =  0.8756666779518127


100%|██████████| 797/797 [00:09<00:00, 84.58it/s]
100%|██████████| 141/141 [00:01<00:00, 109.75it/s]


EPOCH -  12 . Train Accuracy =  0.9085686206817627 , Validation Accuracy =  0.8741111159324646


100%|██████████| 797/797 [00:09<00:00, 86.05it/s]
100%|██████████| 141/141 [00:01<00:00, 108.40it/s]


EPOCH -  13 . Train Accuracy =  0.911098062992096 , Validation Accuracy =  0.8808888792991638
Model Re-Saved


100%|██████████| 797/797 [00:09<00:00, 85.68it/s]
100%|██████████| 141/141 [00:01<00:00, 106.29it/s]


EPOCH -  14 . Train Accuracy =  0.9119411706924438 , Validation Accuracy =  0.8773333430290222


100%|██████████| 797/797 [00:09<00:00, 88.10it/s]
100%|██████████| 141/141 [00:01<00:00, 83.00it/s]


EPOCH -  15 . Train Accuracy =  0.9158627390861511 , Validation Accuracy =  0.8790000081062317


100%|██████████| 797/797 [00:08<00:00, 92.35it/s]
100%|██████████| 141/141 [00:01<00:00, 90.58it/s]


EPOCH -  16 . Train Accuracy =  0.9189019799232483 , Validation Accuracy =  0.8807777762413025


100%|██████████| 797/797 [00:09<00:00, 85.85it/s]
100%|██████████| 141/141 [00:01<00:00, 107.02it/s]


EPOCH -  17 . Train Accuracy =  0.9195882678031921 , Validation Accuracy =  0.871666669845581


100%|██████████| 797/797 [00:09<00:00, 84.80it/s]
100%|██████████| 141/141 [00:01<00:00, 106.72it/s]


EPOCH -  18 . Train Accuracy =  0.92166668176651 , Validation Accuracy =  0.8855555653572083
Model Re-Saved


100%|██████████| 797/797 [00:09<00:00, 85.57it/s]
100%|██████████| 141/141 [00:01<00:00, 108.32it/s]


EPOCH -  19 . Train Accuracy =  0.924490213394165 , Validation Accuracy =  0.8847777843475342


100%|██████████| 797/797 [00:09<00:00, 85.73it/s]
100%|██████████| 141/141 [00:01<00:00, 104.00it/s]


EPOCH -  20 . Train Accuracy =  0.9263725876808167 , Validation Accuracy =  0.8852222561836243


100%|██████████| 797/797 [00:09<00:00, 84.06it/s]
100%|██████████| 141/141 [00:01<00:00, 107.69it/s]


EPOCH -  21 . Train Accuracy =  0.9293921589851379 , Validation Accuracy =  0.8817777633666992


100%|██████████| 797/797 [00:09<00:00, 84.25it/s]
100%|██████████| 141/141 [00:01<00:00, 106.55it/s]


EPOCH -  22 . Train Accuracy =  0.9302549362182617 , Validation Accuracy =  0.8772222399711609


100%|██████████| 797/797 [00:08<00:00, 91.54it/s]
100%|██████████| 141/141 [00:01<00:00, 74.40it/s]


EPOCH -  23 . Train Accuracy =  0.9313333630561829 , Validation Accuracy =  0.8825555443763733


100%|██████████| 797/797 [00:08<00:00, 88.94it/s]
100%|██████████| 141/141 [00:01<00:00, 106.25it/s]


EPOCH -  24 . Train Accuracy =  0.9321960806846619 , Validation Accuracy =  0.8772222399711609


100%|██████████| 797/797 [00:10<00:00, 79.60it/s]
100%|██████████| 141/141 [00:01<00:00, 109.11it/s]


EPOCH -  25 . Train Accuracy =  0.9364314079284668 , Validation Accuracy =  0.8844444751739502


100%|██████████| 797/797 [00:09<00:00, 84.45it/s]
100%|██████████| 141/141 [00:01<00:00, 106.49it/s]


EPOCH -  26 . Train Accuracy =  0.9371568560600281 , Validation Accuracy =  0.886555552482605
Model Re-Saved


100%|██████████| 797/797 [00:09<00:00, 84.07it/s]
100%|██████████| 141/141 [00:01<00:00, 106.54it/s]


EPOCH -  27 . Train Accuracy =  0.9389411807060242 , Validation Accuracy =  0.8815555572509766


100%|██████████| 797/797 [00:09<00:00, 84.07it/s]
100%|██████████| 141/141 [00:01<00:00, 109.10it/s]


EPOCH -  28 . Train Accuracy =  0.9402941465377808 , Validation Accuracy =  0.8827778100967407


100%|██████████| 797/797 [00:09<00:00, 84.27it/s]
100%|██████████| 141/141 [00:01<00:00, 107.87it/s]


EPOCH -  29 . Train Accuracy =  0.9410392642021179 , Validation Accuracy =  0.8856666684150696


100%|██████████| 797/797 [00:08<00:00, 90.58it/s]
100%|██████████| 141/141 [00:01<00:00, 78.63it/s]


EPOCH -  30 . Train Accuracy =  0.9423921704292297 , Validation Accuracy =  0.8835555911064148


100%|██████████| 797/797 [00:08<00:00, 91.50it/s]
100%|██████████| 141/141 [00:01<00:00, 95.21it/s]


EPOCH -  31 . Train Accuracy =  0.9450980424880981 , Validation Accuracy =  0.8856666684150696


100%|██████████| 797/797 [00:09<00:00, 84.76it/s]
100%|██████████| 141/141 [00:01<00:00, 107.79it/s]


EPOCH -  32 . Train Accuracy =  0.9466274976730347 , Validation Accuracy =  0.8858888745307922


100%|██████████| 797/797 [00:09<00:00, 84.63it/s]
100%|██████████| 141/141 [00:01<00:00, 106.49it/s]


EPOCH -  33 . Train Accuracy =  0.9464313983917236 , Validation Accuracy =  0.8777778148651123


100%|██████████| 797/797 [00:09<00:00, 85.47it/s]
100%|██████████| 141/141 [00:01<00:00, 109.16it/s]


EPOCH -  34 . Train Accuracy =  0.9475098252296448 , Validation Accuracy =  0.8849999904632568


100%|██████████| 797/797 [00:09<00:00, 85.34it/s]
100%|██████████| 141/141 [00:01<00:00, 104.63it/s]


EPOCH -  35 . Train Accuracy =  0.9496274590492249 , Validation Accuracy =  0.8815555572509766


100%|██████████| 797/797 [00:09<00:00, 85.55it/s]
100%|██████████| 141/141 [00:01<00:00, 106.98it/s]


EPOCH -  36 . Train Accuracy =  0.949647068977356 , Validation Accuracy =  0.8833333253860474


100%|██████████| 797/797 [00:09<00:00, 82.64it/s]
100%|██████████| 141/141 [00:01<00:00, 96.28it/s] 


EPOCH -  37 . Train Accuracy =  0.9513333439826965 , Validation Accuracy =  0.8853333592414856


100%|██████████| 797/797 [00:09<00:00, 86.42it/s]
100%|██████████| 141/141 [00:01<00:00, 80.77it/s]


EPOCH -  38 . Train Accuracy =  0.953529417514801 , Validation Accuracy =  0.8870000243186951
Model Re-Saved


100%|██████████| 797/797 [00:08<00:00, 89.37it/s]
100%|██████████| 141/141 [00:01<00:00, 85.79it/s]


EPOCH -  39 . Train Accuracy =  0.9539999961853027 , Validation Accuracy =  0.8807777762413025


100%|██████████| 797/797 [00:09<00:00, 85.97it/s]
100%|██████████| 141/141 [00:01<00:00, 103.96it/s]


EPOCH -  40 . Train Accuracy =  0.9551960825920105 , Validation Accuracy =  0.8830000162124634


100%|██████████| 797/797 [00:09<00:00, 83.89it/s]
100%|██████████| 141/141 [00:01<00:00, 107.55it/s]


EPOCH -  41 . Train Accuracy =  0.9565098285675049 , Validation Accuracy =  0.8765555620193481


100%|██████████| 797/797 [00:09<00:00, 84.34it/s]
100%|██████████| 141/141 [00:01<00:00, 106.45it/s]


EPOCH -  42 . Train Accuracy =  0.9579607844352722 , Validation Accuracy =  0.8798888921737671


100%|██████████| 797/797 [00:10<00:00, 77.70it/s]
100%|██████████| 141/141 [00:01<00:00, 105.72it/s]


EPOCH -  43 . Train Accuracy =  0.9578823447227478 , Validation Accuracy =  0.8798888921737671


100%|██████████| 797/797 [00:09<00:00, 85.23it/s]
100%|██████████| 141/141 [00:01<00:00, 105.41it/s]


EPOCH -  44 . Train Accuracy =  0.9591372609138489 , Validation Accuracy =  0.8848888874053955


100%|██████████| 797/797 [00:09<00:00, 84.06it/s]
100%|██████████| 141/141 [00:01<00:00, 109.93it/s]


EPOCH -  45 . Train Accuracy =  0.9590980410575867 , Validation Accuracy =  0.8777778148651123


100%|██████████| 797/797 [00:09<00:00, 84.11it/s]
100%|██████████| 141/141 [00:01<00:00, 107.27it/s]


EPOCH -  46 . Train Accuracy =  0.9605294466018677 , Validation Accuracy =  0.8854444622993469


100%|██████████| 797/797 [00:08<00:00, 90.86it/s]
100%|██████████| 141/141 [00:01<00:00, 77.26it/s]


EPOCH -  47 . Train Accuracy =  0.9609804153442383 , Validation Accuracy =  0.882444441318512


100%|██████████| 797/797 [00:08<00:00, 90.79it/s]
100%|██████████| 141/141 [00:01<00:00, 107.99it/s]


EPOCH -  48 . Train Accuracy =  0.9638431668281555 , Validation Accuracy =  0.8844444751739502


100%|██████████| 797/797 [00:09<00:00, 85.55it/s]
100%|██████████| 141/141 [00:01<00:00, 107.99it/s]


EPOCH -  49 . Train Accuracy =  0.9618039131164551 , Validation Accuracy =  0.8820000290870667


100%|██████████| 797/797 [00:09<00:00, 84.83it/s]
100%|██████████| 141/141 [00:01<00:00, 105.37it/s]


EPOCH -  50 . Train Accuracy =  0.9648823738098145 , Validation Accuracy =  0.8876667022705078
Model Re-Saved


100%|██████████| 797/797 [00:09<00:00, 85.62it/s]
100%|██████████| 141/141 [00:01<00:00, 107.30it/s]


EPOCH -  51 . Train Accuracy =  0.9636666774749756 , Validation Accuracy =  0.886222243309021


100%|██████████| 797/797 [00:09<00:00, 84.26it/s]
100%|██████████| 141/141 [00:01<00:00, 105.42it/s]


EPOCH -  52 . Train Accuracy =  0.966372549533844 , Validation Accuracy =  0.8843333721160889


100%|██████████| 797/797 [00:09<00:00, 83.98it/s]
100%|██████████| 141/141 [00:01<00:00, 107.18it/s]


EPOCH -  53 . Train Accuracy =  0.965509831905365 , Validation Accuracy =  0.8848888874053955


100%|██████████| 797/797 [00:09<00:00, 87.81it/s]
100%|██████████| 141/141 [00:01<00:00, 88.01it/s] 


EPOCH -  54 . Train Accuracy =  0.9662941098213196 , Validation Accuracy =  0.883222222328186


100%|██████████| 797/797 [00:08<00:00, 92.34it/s]
100%|██████████| 141/141 [00:01<00:00, 89.95it/s]


EPOCH -  55 . Train Accuracy =  0.9689216017723083 , Validation Accuracy =  0.8847777843475342


100%|██████████| 797/797 [00:09<00:00, 88.19it/s]
100%|██████████| 141/141 [00:01<00:00, 105.84it/s]


EPOCH -  56 . Train Accuracy =  0.9697843194007874 , Validation Accuracy =  0.8840000033378601


100%|██████████| 797/797 [00:09<00:00, 86.20it/s]
100%|██████████| 141/141 [00:01<00:00, 109.98it/s]


EPOCH -  57 . Train Accuracy =  0.9686862826347351 , Validation Accuracy =  0.8877778053283691
Model Re-Saved


100%|██████████| 797/797 [00:09<00:00, 84.73it/s]
100%|██████████| 141/141 [00:01<00:00, 109.02it/s]


EPOCH -  58 . Train Accuracy =  0.9699804186820984 , Validation Accuracy =  0.883222222328186


100%|██████████| 797/797 [00:09<00:00, 86.95it/s]
100%|██████████| 141/141 [00:01<00:00, 109.00it/s]


EPOCH -  59 . Train Accuracy =  0.9699608087539673 , Validation Accuracy =  0.8861111402511597


100%|██████████| 797/797 [00:09<00:00, 86.50it/s]
100%|██████████| 141/141 [00:01<00:00, 109.29it/s]


EPOCH -  60 . Train Accuracy =  0.9708039164543152 , Validation Accuracy =  0.882888913154602


100%|██████████| 797/797 [00:10<00:00, 79.00it/s]
100%|██████████| 141/141 [00:01<00:00, 105.61it/s]


EPOCH -  61 . Train Accuracy =  0.9715490341186523 , Validation Accuracy =  0.8785555362701416


100%|██████████| 797/797 [00:08<00:00, 91.24it/s]
100%|██████████| 141/141 [00:01<00:00, 76.69it/s]


EPOCH -  62 . Train Accuracy =  0.9713529348373413 , Validation Accuracy =  0.8848888874053955


100%|██████████| 797/797 [00:08<00:00, 90.90it/s]
100%|██████████| 141/141 [00:01<00:00, 107.84it/s]


EPOCH -  63 . Train Accuracy =  0.9713529348373413 , Validation Accuracy =  0.8843333721160889


100%|██████████| 797/797 [00:09<00:00, 85.98it/s]
100%|██████████| 141/141 [00:01<00:00, 107.25it/s]


EPOCH -  64 . Train Accuracy =  0.9731960892677307 , Validation Accuracy =  0.8835555911064148


100%|██████████| 797/797 [00:09<00:00, 85.42it/s]
100%|██████████| 141/141 [00:01<00:00, 109.75it/s]


EPOCH -  65 . Train Accuracy =  0.9725882411003113 , Validation Accuracy =  0.8772222399711609


100%|██████████| 797/797 [00:09<00:00, 84.18it/s]
100%|██████████| 141/141 [00:01<00:00, 110.48it/s]


EPOCH -  66 . Train Accuracy =  0.9732941389083862 , Validation Accuracy =  0.8759999871253967


100%|██████████| 797/797 [00:09<00:00, 85.12it/s]
100%|██████████| 141/141 [00:01<00:00, 108.20it/s]


EPOCH -  67 . Train Accuracy =  0.9737451076507568 , Validation Accuracy =  0.886222243309021


100%|██████████| 797/797 [00:09<00:00, 84.35it/s]
100%|██████████| 141/141 [00:01<00:00, 110.44it/s]


EPOCH -  68 . Train Accuracy =  0.974823534488678 , Validation Accuracy =  0.8831111192703247


100%|██████████| 797/797 [00:09<00:00, 87.60it/s]
100%|██████████| 141/141 [00:01<00:00, 87.81it/s] 


EPOCH -  69 . Train Accuracy =  0.9761568903923035 , Validation Accuracy =  0.8793333172798157


100%|██████████| 797/797 [00:08<00:00, 92.50it/s]
100%|██████████| 141/141 [00:01<00:00, 91.33it/s]


EPOCH -  70 . Train Accuracy =  0.9743529558181763 , Validation Accuracy =  0.8745555877685547


100%|██████████| 797/797 [00:08<00:00, 89.45it/s]
100%|██████████| 141/141 [00:01<00:00, 108.83it/s]


EPOCH -  71 . Train Accuracy =  0.9752157330513 , Validation Accuracy =  0.8845555782318115


100%|██████████| 797/797 [00:09<00:00, 86.18it/s]
100%|██████████| 141/141 [00:01<00:00, 105.89it/s]


EPOCH -  72 . Train Accuracy =  0.9757451415061951 , Validation Accuracy =  0.8799999952316284


100%|██████████| 797/797 [00:09<00:00, 84.66it/s]
100%|██████████| 141/141 [00:01<00:00, 106.06it/s]


EPOCH -  73 . Train Accuracy =  0.9777255058288574 , Validation Accuracy =  0.8854444622993469


100%|██████████| 797/797 [00:09<00:00, 85.00it/s]
100%|██████████| 141/141 [00:01<00:00, 109.86it/s]


EPOCH -  74 . Train Accuracy =  0.9775294065475464 , Validation Accuracy =  0.8715555667877197


100%|██████████| 797/797 [00:09<00:00, 85.27it/s]
100%|██████████| 141/141 [00:01<00:00, 108.93it/s]


EPOCH -  75 . Train Accuracy =  0.977843165397644 , Validation Accuracy =  0.8848888874053955


100%|██████████| 797/797 [00:09<00:00, 87.63it/s]
100%|██████████| 141/141 [00:01<00:00, 92.29it/s] 


EPOCH -  76 . Train Accuracy =  0.9772941470146179 , Validation Accuracy =  0.8868889212608337


100%|██████████| 797/797 [00:08<00:00, 91.40it/s]
100%|██████████| 141/141 [00:01<00:00, 82.52it/s]


EPOCH -  77 . Train Accuracy =  0.979607880115509 , Validation Accuracy =  0.8792222142219543


100%|██████████| 797/797 [00:08<00:00, 89.58it/s]
100%|██████████| 141/141 [00:01<00:00, 96.65it/s]


EPOCH -  78 . Train Accuracy =  0.979078471660614 , Validation Accuracy =  0.8798888921737671


100%|██████████| 797/797 [00:10<00:00, 76.98it/s]
100%|██████████| 141/141 [00:01<00:00, 108.75it/s]


EPOCH -  79 . Train Accuracy =  0.9787843227386475 , Validation Accuracy =  0.8836666941642761


100%|██████████| 797/797 [00:09<00:00, 84.95it/s]
100%|██████████| 141/141 [00:01<00:00, 107.71it/s]


EPOCH -  80 . Train Accuracy =  0.9802157282829285 , Validation Accuracy =  0.886222243309021


100%|██████████| 797/797 [00:09<00:00, 85.67it/s]
100%|██████████| 141/141 [00:01<00:00, 107.91it/s]


EPOCH -  81 . Train Accuracy =  0.979941189289093 , Validation Accuracy =  0.8854444622993469


100%|██████████| 797/797 [00:09<00:00, 85.48it/s]
100%|██████████| 141/141 [00:01<00:00, 109.42it/s]


EPOCH -  82 . Train Accuracy =  0.9807647466659546 , Validation Accuracy =  0.8896666765213013
Model Re-Saved


100%|██████████| 797/797 [00:09<00:00, 85.54it/s]
100%|██████████| 141/141 [00:01<00:00, 111.55it/s]


EPOCH -  83 . Train Accuracy =  0.980137288570404 , Validation Accuracy =  0.8849999904632568


100%|██████████| 797/797 [00:08<00:00, 89.66it/s]
100%|██████████| 141/141 [00:01<00:00, 75.71it/s]


EPOCH -  84 . Train Accuracy =  0.9788039326667786 , Validation Accuracy =  0.8812222480773926


100%|██████████| 797/797 [00:08<00:00, 92.36it/s]
100%|██████████| 141/141 [00:01<00:00, 105.98it/s]


EPOCH -  85 . Train Accuracy =  0.9812157154083252 , Validation Accuracy =  0.8841111063957214


100%|██████████| 797/797 [00:09<00:00, 86.45it/s]
100%|██████████| 141/141 [00:01<00:00, 109.57it/s]


EPOCH -  86 . Train Accuracy =  0.9822157025337219 , Validation Accuracy =  0.8872222304344177


100%|██████████| 797/797 [00:09<00:00, 85.95it/s]
100%|██████████| 141/141 [00:01<00:00, 109.49it/s]


EPOCH -  87 . Train Accuracy =  0.9798627495765686 , Validation Accuracy =  0.8843333721160889


100%|██████████| 797/797 [00:09<00:00, 86.49it/s]
100%|██████████| 141/141 [00:01<00:00, 105.12it/s]


EPOCH -  88 . Train Accuracy =  0.9810784459114075 , Validation Accuracy =  0.8845555782318115


100%|██████████| 797/797 [00:09<00:00, 84.82it/s]
100%|██████████| 141/141 [00:01<00:00, 108.14it/s]


EPOCH -  89 . Train Accuracy =  0.9836862683296204 , Validation Accuracy =  0.8834444284439087


100%|██████████| 797/797 [00:09<00:00, 84.65it/s]
100%|██████████| 141/141 [00:01<00:00, 108.28it/s]


EPOCH -  90 . Train Accuracy =  0.9817451238632202 , Validation Accuracy =  0.8896666765213013


100%|██████████| 797/797 [00:08<00:00, 89.52it/s]
100%|██████████| 141/141 [00:01<00:00, 81.48it/s]


EPOCH -  91 . Train Accuracy =  0.9828235507011414 , Validation Accuracy =  0.8834444284439087


100%|██████████| 797/797 [00:08<00:00, 90.60it/s]
100%|██████████| 141/141 [00:01<00:00, 88.04it/s]


EPOCH -  92 . Train Accuracy =  0.9838627576828003 , Validation Accuracy =  0.8840000033378601


100%|██████████| 797/797 [00:09<00:00, 87.69it/s]
100%|██████████| 141/141 [00:01<00:00, 107.44it/s]


EPOCH -  93 . Train Accuracy =  0.9835686683654785 , Validation Accuracy =  0.8853333592414856


100%|██████████| 797/797 [00:09<00:00, 84.16it/s]
100%|██████████| 141/141 [00:01<00:00, 108.64it/s]


EPOCH -  94 . Train Accuracy =  0.9833529591560364 , Validation Accuracy =  0.8825555443763733


100%|██████████| 797/797 [00:09<00:00, 84.88it/s]
100%|██████████| 141/141 [00:01<00:00, 105.83it/s]


EPOCH -  95 . Train Accuracy =  0.9824314117431641 , Validation Accuracy =  0.8863333463668823


100%|██████████| 797/797 [00:09<00:00, 84.56it/s]
100%|██████████| 141/141 [00:01<00:00, 107.71it/s]


EPOCH -  96 . Train Accuracy =  0.9835098385810852 , Validation Accuracy =  0.8880000114440918


100%|██████████| 797/797 [00:10<00:00, 78.53it/s]
100%|██████████| 141/141 [00:01<00:00, 108.40it/s]


EPOCH -  97 . Train Accuracy =  0.9844509959220886 , Validation Accuracy =  0.8870000243186951


100%|██████████| 797/797 [00:09<00:00, 84.52it/s]
100%|██████████| 141/141 [00:01<00:00, 104.69it/s]


EPOCH -  98 . Train Accuracy =  0.9842352867126465 , Validation Accuracy =  0.8825555443763733


100%|██████████| 797/797 [00:09<00:00, 87.03it/s]
100%|██████████| 141/141 [00:01<00:00, 89.27it/s] 


EPOCH -  99 . Train Accuracy =  0.9840784668922424 , Validation Accuracy =  0.8846666812896729


100%|██████████| 797/797 [00:08<00:00, 91.58it/s]
100%|██████████| 141/141 [00:01<00:00, 88.05it/s]


EPOCH -  100 . Train Accuracy =  0.9847647547721863 , Validation Accuracy =  0.8844444751739502


100%|██████████| 797/797 [00:07<00:00, 102.33it/s]


Train Accuracy =  0.9824314117431641


100%|██████████| 157/157 [00:01<00:00, 108.19it/s]

Test Accuracy =  0.8844999670982361
Generalization Gap =  0.09793144464492798





In [None]:
# -------------------------------------------- STL 10 Dataset -----------------------------------------------------------
Option = 3
Train_DataLoader, Val_DataLoader, Test_DataLoader = CreateDataLoaders(Option, 64)
ModelName = "Model_MLP_IOCN_STL10.pt"

model = MLP_IOCN(27648, 10).to(device)
loss_criteria = nn.CrossEntropyLoss()
AdamOpt = torch.optim.Adam(model.parameters(), lr=0.0001)

ValAccuracy = 0

saved = False
if saved == True:
    saved_model = torch.load(ModelName, map_location=torch.device('cpu')).to(device)
else:
    model.train()
    Train_Accuracy = 0

    for e in range(EPOCHS):
        Train_Accuracy = TrainModel(model, loss_criteria, AdamOpt, device, Train_DataLoader)
        val_acc = EvaluateModel(model, Val_DataLoader, device)

        print("EPOCH - ", e+1, ". Train Accuracy = ", Train_Accuracy.cpu().item(), ", Validation Accuracy = ", val_acc.cpu().item())
        
        if val_acc.cpu().item() > ValAccuracy:
            print("Model Re-Saved")
            ValAccuracy = val_acc.cpu().item()
            torch.save(model, ModelName)


    saved_model = torch.load(ModelName, map_location=torch.device('cpu')).to(device)

print()
Train_Accuracy = EvaluateModel(saved_model, Train_DataLoader, device)
print("Train Accuracy = ", Train_Accuracy.cpu().item())
Test_Accuracy = EvaluateModel(saved_model, Test_DataLoader, device)
print("Test Accuracy = ", Test_Accuracy.cpu().item())

print("Generalization Gap = ", (Train_Accuracy.cpu().item() - Test_Accuracy.cpu().item()))

Files already downloaded and verified
Files already downloaded and verified


100%|██████████| 67/67 [00:03<00:00, 21.65it/s]
100%|██████████| 12/12 [00:00<00:00, 23.60it/s]


EPOCH -  1 . Train Accuracy =  0.1576470583677292 , Validation Accuracy =  0.14533333480358124
Model Re-Saved


100%|██████████| 67/67 [00:02<00:00, 27.60it/s]
100%|██████████| 12/12 [00:00<00:00, 35.00it/s]


EPOCH -  2 . Train Accuracy =  0.14541177451610565 , Validation Accuracy =  0.13066665828227997


100%|██████████| 67/67 [00:02<00:00, 32.79it/s]
100%|██████████| 12/12 [00:00<00:00, 36.44it/s]


EPOCH -  3 . Train Accuracy =  0.14658823609352112 , Validation Accuracy =  0.14266666769981384


100%|██████████| 67/67 [00:02<00:00, 31.28it/s]
100%|██████████| 12/12 [00:00<00:00, 35.21it/s]


EPOCH -  4 . Train Accuracy =  0.17929412424564362 , Validation Accuracy =  0.164000004529953
Model Re-Saved


100%|██████████| 67/67 [00:02<00:00, 32.20it/s]
100%|██████████| 12/12 [00:00<00:00, 33.99it/s]


EPOCH -  5 . Train Accuracy =  0.18752941489219666 , Validation Accuracy =  0.14399999380111694


100%|██████████| 67/67 [00:02<00:00, 23.78it/s]
100%|██████████| 12/12 [00:00<00:00, 34.47it/s]


EPOCH -  6 . Train Accuracy =  0.19035294651985168 , Validation Accuracy =  0.14666666090488434


100%|██████████| 67/67 [00:02<00:00, 31.42it/s]
100%|██████████| 12/12 [00:00<00:00, 34.79it/s]


EPOCH -  7 . Train Accuracy =  0.18376471102237701 , Validation Accuracy =  0.15733332931995392


100%|██████████| 67/67 [00:02<00:00, 31.95it/s]
100%|██████████| 12/12 [00:00<00:00, 35.59it/s]


EPOCH -  8 . Train Accuracy =  0.18635293841362 , Validation Accuracy =  0.164000004529953


100%|██████████| 67/67 [00:02<00:00, 32.34it/s]
100%|██████████| 12/12 [00:00<00:00, 32.77it/s]


EPOCH -  9 . Train Accuracy =  0.18870589137077332 , Validation Accuracy =  0.1586666703224182


100%|██████████| 67/67 [00:02<00:00, 30.74it/s]
100%|██████████| 12/12 [00:00<00:00, 24.10it/s]


EPOCH -  10 . Train Accuracy =  0.1957647055387497 , Validation Accuracy =  0.1706666648387909
Model Re-Saved


100%|██████████| 67/67 [00:02<00:00, 27.95it/s]
100%|██████████| 12/12 [00:00<00:00, 35.58it/s]


EPOCH -  11 . Train Accuracy =  0.21529412269592285 , Validation Accuracy =  0.18799999356269836
Model Re-Saved


100%|██████████| 67/67 [00:02<00:00, 31.25it/s]
100%|██████████| 12/12 [00:00<00:00, 35.61it/s]


EPOCH -  12 . Train Accuracy =  0.22752942144870758 , Validation Accuracy =  0.164000004529953


100%|██████████| 67/67 [00:02<00:00, 31.84it/s]
100%|██████████| 12/12 [00:00<00:00, 33.68it/s]


EPOCH -  13 . Train Accuracy =  0.2329411804676056 , Validation Accuracy =  0.17866666615009308


100%|██████████| 67/67 [00:02<00:00, 32.84it/s]
100%|██████████| 12/12 [00:00<00:00, 34.84it/s]


EPOCH -  14 . Train Accuracy =  0.24047058820724487 , Validation Accuracy =  0.1613333374261856


100%|██████████| 67/67 [00:02<00:00, 24.50it/s]
100%|██████████| 12/12 [00:00<00:00, 29.47it/s]


EPOCH -  15 . Train Accuracy =  0.23129412531852722 , Validation Accuracy =  0.1693333387374878


100%|██████████| 67/67 [00:02<00:00, 31.57it/s]
100%|██████████| 12/12 [00:00<00:00, 35.71it/s]


EPOCH -  16 . Train Accuracy =  0.24211765825748444 , Validation Accuracy =  0.17999999225139618


100%|██████████| 67/67 [00:02<00:00, 30.51it/s]
100%|██████████| 12/12 [00:00<00:00, 34.31it/s]


EPOCH -  17 . Train Accuracy =  0.2508235275745392 , Validation Accuracy =  0.1746666580438614


100%|██████████| 67/67 [00:02<00:00, 28.76it/s]
100%|██████████| 12/12 [00:00<00:00, 25.33it/s]


EPOCH -  18 . Train Accuracy =  0.2571764886379242 , Validation Accuracy =  0.17599999904632568


100%|██████████| 67/67 [00:02<00:00, 31.37it/s]
100%|██████████| 12/12 [00:00<00:00, 24.66it/s]


EPOCH -  19 . Train Accuracy =  0.2616470754146576 , Validation Accuracy =  0.1679999977350235


100%|██████████| 67/67 [00:02<00:00, 25.95it/s]
100%|██████████| 12/12 [00:00<00:00, 33.92it/s]


EPOCH -  20 . Train Accuracy =  0.2691764831542969 , Validation Accuracy =  0.18933333456516266
Model Re-Saved


100%|██████████| 67/67 [00:02<00:00, 32.13it/s]
100%|██████████| 12/12 [00:00<00:00, 34.84it/s]


EPOCH -  21 . Train Accuracy =  0.288470596075058 , Validation Accuracy =  0.19066666066646576
Model Re-Saved


100%|██████████| 67/67 [00:02<00:00, 31.65it/s]
100%|██████████| 12/12 [00:00<00:00, 34.28it/s]


EPOCH -  22 . Train Accuracy =  0.2927058935165405 , Validation Accuracy =  0.17733332514762878


100%|██████████| 67/67 [00:02<00:00, 31.32it/s]
100%|██████████| 12/12 [00:00<00:00, 33.23it/s]


EPOCH -  23 . Train Accuracy =  0.30941176414489746 , Validation Accuracy =  0.19200000166893005
Model Re-Saved


100%|██████████| 67/67 [00:02<00:00, 23.45it/s]
100%|██████████| 12/12 [00:00<00:00, 33.53it/s]


EPOCH -  24 . Train Accuracy =  0.32141175866127014 , Validation Accuracy =  0.19066666066646576


100%|██████████| 67/67 [00:02<00:00, 32.62it/s]
100%|██████████| 12/12 [00:00<00:00, 36.48it/s]


EPOCH -  25 . Train Accuracy =  0.33317646384239197 , Validation Accuracy =  0.18666666746139526


100%|██████████| 67/67 [00:02<00:00, 32.43it/s]
100%|██████████| 12/12 [00:00<00:00, 34.26it/s]


EPOCH -  26 . Train Accuracy =  0.34705883264541626 , Validation Accuracy =  0.19466666877269745
Model Re-Saved


100%|██████████| 67/67 [00:02<00:00, 31.80it/s]
100%|██████████| 12/12 [00:00<00:00, 35.46it/s]


EPOCH -  27 . Train Accuracy =  0.36400002241134644 , Validation Accuracy =  0.20399999618530273
Model Re-Saved


100%|██████████| 67/67 [00:02<00:00, 29.38it/s]
100%|██████████| 12/12 [00:00<00:00, 25.13it/s]


EPOCH -  28 . Train Accuracy =  0.3611764907836914 , Validation Accuracy =  0.20533333718776703
Model Re-Saved


100%|██████████| 67/67 [00:02<00:00, 27.04it/s]
100%|██████████| 12/12 [00:00<00:00, 32.92it/s]


EPOCH -  29 . Train Accuracy =  0.36635294556617737 , Validation Accuracy =  0.19733333587646484


100%|██████████| 67/67 [00:02<00:00, 31.96it/s]
100%|██████████| 12/12 [00:00<00:00, 34.28it/s]


EPOCH -  30 . Train Accuracy =  0.4018823504447937 , Validation Accuracy =  0.19066666066646576


100%|██████████| 67/67 [00:02<00:00, 31.80it/s]
100%|██████████| 12/12 [00:00<00:00, 34.23it/s]


EPOCH -  31 . Train Accuracy =  0.40847060084342957 , Validation Accuracy =  0.19599999487400055


100%|██████████| 67/67 [00:02<00:00, 31.40it/s]
100%|██████████| 12/12 [00:00<00:00, 32.05it/s]


EPOCH -  32 . Train Accuracy =  0.4037647247314453 , Validation Accuracy =  0.2173333317041397
Model Re-Saved


100%|██████████| 67/67 [00:02<00:00, 24.12it/s]
100%|██████████| 12/12 [00:00<00:00, 33.44it/s]


EPOCH -  33 . Train Accuracy =  0.4355294108390808 , Validation Accuracy =  0.2186666578054428
Model Re-Saved


100%|██████████| 67/67 [00:02<00:00, 32.19it/s]
100%|██████████| 12/12 [00:00<00:00, 34.93it/s]


EPOCH -  34 . Train Accuracy =  0.43670588731765747 , Validation Accuracy =  0.2133333384990692


100%|██████████| 67/67 [00:02<00:00, 32.18it/s]
100%|██████████| 12/12 [00:00<00:00, 34.89it/s]


EPOCH -  35 . Train Accuracy =  0.4508235454559326 , Validation Accuracy =  0.2253333330154419
Model Re-Saved


100%|██████████| 67/67 [00:02<00:00, 32.15it/s]
100%|██████████| 12/12 [00:00<00:00, 35.49it/s]


EPOCH -  36 . Train Accuracy =  0.46705883741378784 , Validation Accuracy =  0.20666666328907013


100%|██████████| 67/67 [00:02<00:00, 29.86it/s]
100%|██████████| 12/12 [00:00<00:00, 23.93it/s]


EPOCH -  37 . Train Accuracy =  0.4515294134616852 , Validation Accuracy =  0.21599999070167542


100%|██████████| 67/67 [00:02<00:00, 26.62it/s]
100%|██████████| 12/12 [00:00<00:00, 35.80it/s]


EPOCH -  38 . Train Accuracy =  0.48541176319122314 , Validation Accuracy =  0.2186666578054428


100%|██████████| 67/67 [00:02<00:00, 32.46it/s]
100%|██████████| 12/12 [00:00<00:00, 35.13it/s]


EPOCH -  39 . Train Accuracy =  0.4988235533237457 , Validation Accuracy =  0.19866666197776794


100%|██████████| 67/67 [00:02<00:00, 32.38it/s]
100%|██████████| 12/12 [00:00<00:00, 34.74it/s]


EPOCH -  40 . Train Accuracy =  0.5018823742866516 , Validation Accuracy =  0.20133332908153534


100%|██████████| 67/67 [00:02<00:00, 32.64it/s]
100%|██████████| 12/12 [00:00<00:00, 36.15it/s]


EPOCH -  41 . Train Accuracy =  0.5134117603302002 , Validation Accuracy =  0.2253333330154419


100%|██████████| 67/67 [00:02<00:00, 27.15it/s]
100%|██████████| 12/12 [00:00<00:00, 24.36it/s]


EPOCH -  42 . Train Accuracy =  0.5176470875740051 , Validation Accuracy =  0.21066667139530182


100%|██████████| 67/67 [00:02<00:00, 28.50it/s]
100%|██████████| 12/12 [00:00<00:00, 34.99it/s]


EPOCH -  43 . Train Accuracy =  0.5341176390647888 , Validation Accuracy =  0.21599999070167542


100%|██████████| 67/67 [00:02<00:00, 32.73it/s]
100%|██████████| 12/12 [00:00<00:00, 35.54it/s]


EPOCH -  44 . Train Accuracy =  0.5360000133514404 , Validation Accuracy =  0.2239999920129776


100%|██████████| 67/67 [00:02<00:00, 32.71it/s]
100%|██████████| 12/12 [00:00<00:00, 34.94it/s]


EPOCH -  45 . Train Accuracy =  0.5823529362678528 , Validation Accuracy =  0.226666659116745
Model Re-Saved


100%|██████████| 67/67 [00:02<00:00, 32.42it/s]
100%|██████████| 12/12 [00:00<00:00, 33.55it/s]


EPOCH -  46 . Train Accuracy =  0.5849412083625793 , Validation Accuracy =  0.23733332753181458
Model Re-Saved


100%|██████████| 67/67 [00:02<00:00, 23.63it/s]
100%|██████████| 12/12 [00:00<00:00, 34.63it/s]


EPOCH -  47 . Train Accuracy =  0.6089411973953247 , Validation Accuracy =  0.23733332753181458


100%|██████████| 67/67 [00:02<00:00, 31.50it/s]
100%|██████████| 12/12 [00:00<00:00, 34.51it/s]


EPOCH -  48 . Train Accuracy =  0.5887058973312378 , Validation Accuracy =  0.24400000274181366
Model Re-Saved


100%|██████████| 67/67 [00:02<00:00, 32.55it/s]
100%|██████████| 12/12 [00:00<00:00, 34.23it/s]


EPOCH -  49 . Train Accuracy =  0.6110588312149048 , Validation Accuracy =  0.226666659116745


100%|██████████| 67/67 [00:02<00:00, 32.03it/s]
100%|██████████| 12/12 [00:00<00:00, 35.28it/s]


EPOCH -  50 . Train Accuracy =  0.6439999938011169 , Validation Accuracy =  0.24266666173934937


100%|██████████| 67/67 [00:02<00:00, 30.47it/s]
100%|██████████| 12/12 [00:00<00:00, 23.50it/s]


EPOCH -  51 . Train Accuracy =  0.6444706320762634 , Validation Accuracy =  0.23733332753181458


100%|██████████| 67/67 [00:02<00:00, 26.14it/s]
100%|██████████| 12/12 [00:00<00:00, 33.24it/s]


EPOCH -  52 . Train Accuracy =  0.6687058806419373 , Validation Accuracy =  0.25866666436195374
Model Re-Saved


100%|██████████| 67/67 [00:02<00:00, 31.34it/s]
100%|██████████| 12/12 [00:00<00:00, 34.71it/s]


EPOCH -  53 . Train Accuracy =  0.658823549747467 , Validation Accuracy =  0.23999999463558197


100%|██████████| 67/67 [00:02<00:00, 32.56it/s]
100%|██████████| 12/12 [00:00<00:00, 35.37it/s]


EPOCH -  54 . Train Accuracy =  0.6564705967903137 , Validation Accuracy =  0.24133333563804626


100%|██████████| 67/67 [00:02<00:00, 33.19it/s]
100%|██████████| 12/12 [00:00<00:00, 35.97it/s]


EPOCH -  55 . Train Accuracy =  0.6840000152587891 , Validation Accuracy =  0.23999999463558197


100%|██████████| 67/67 [00:02<00:00, 26.41it/s]
100%|██████████| 12/12 [00:00<00:00, 24.23it/s]


EPOCH -  56 . Train Accuracy =  0.6842353343963623 , Validation Accuracy =  0.25466665625572205


100%|██████████| 67/67 [00:02<00:00, 31.13it/s]
100%|██████████| 12/12 [00:00<00:00, 35.94it/s]


EPOCH -  57 . Train Accuracy =  0.6910588145256042 , Validation Accuracy =  0.24799999594688416


100%|██████████| 67/67 [00:02<00:00, 30.69it/s]
100%|██████████| 12/12 [00:00<00:00, 35.58it/s]


EPOCH -  58 . Train Accuracy =  0.7277647256851196 , Validation Accuracy =  0.24933333694934845


100%|██████████| 67/67 [00:02<00:00, 32.19it/s]
100%|██████████| 12/12 [00:00<00:00, 35.56it/s]


EPOCH -  59 . Train Accuracy =  0.7428235411643982 , Validation Accuracy =  0.24266666173934937


100%|██████████| 67/67 [00:02<00:00, 33.05it/s]
100%|██████████| 12/12 [00:00<00:00, 33.80it/s]


EPOCH -  60 . Train Accuracy =  0.7458823919296265 , Validation Accuracy =  0.23333333432674408


100%|██████████| 67/67 [00:02<00:00, 24.06it/s]
100%|██████████| 12/12 [00:00<00:00, 35.91it/s]


EPOCH -  61 . Train Accuracy =  0.7268235683441162 , Validation Accuracy =  0.24400000274181366


100%|██████████| 67/67 [00:02<00:00, 32.13it/s]
100%|██████████| 12/12 [00:00<00:00, 34.76it/s]


EPOCH -  62 . Train Accuracy =  0.7658823728561401 , Validation Accuracy =  0.2719999849796295
Model Re-Saved


100%|██████████| 67/67 [00:02<00:00, 32.02it/s]
100%|██████████| 12/12 [00:00<00:00, 33.02it/s]


EPOCH -  63 . Train Accuracy =  0.784000039100647 , Validation Accuracy =  0.281333327293396
Model Re-Saved


100%|██████████| 67/67 [00:02<00:00, 32.24it/s]
100%|██████████| 12/12 [00:00<00:00, 33.91it/s]


EPOCH -  64 . Train Accuracy =  0.7837647199630737 , Validation Accuracy =  0.2639999985694885


100%|██████████| 67/67 [00:02<00:00, 29.36it/s]
100%|██████████| 12/12 [00:00<00:00, 23.25it/s]


EPOCH -  65 . Train Accuracy =  0.7687059044837952 , Validation Accuracy =  0.24666666984558105


100%|██████████| 67/67 [00:02<00:00, 27.75it/s]
100%|██████████| 12/12 [00:00<00:00, 35.62it/s]


EPOCH -  66 . Train Accuracy =  0.7621176838874817 , Validation Accuracy =  0.2773333191871643


100%|██████████| 67/67 [00:02<00:00, 32.70it/s]
100%|██████████| 12/12 [00:00<00:00, 34.26it/s]


EPOCH -  67 . Train Accuracy =  0.7703529596328735 , Validation Accuracy =  0.2826666533946991
Model Re-Saved


100%|██████████| 67/67 [00:02<00:00, 32.44it/s]
100%|██████████| 12/12 [00:00<00:00, 34.09it/s]


EPOCH -  68 . Train Accuracy =  0.8011764883995056 , Validation Accuracy =  0.2786666750907898


100%|██████████| 67/67 [00:02<00:00, 30.46it/s]
100%|██████████| 12/12 [00:00<00:00, 35.55it/s]


EPOCH -  69 . Train Accuracy =  0.8136470913887024 , Validation Accuracy =  0.281333327293396


100%|██████████| 67/67 [00:02<00:00, 24.79it/s]
100%|██████████| 12/12 [00:00<00:00, 27.11it/s]


EPOCH -  70 . Train Accuracy =  0.8150588274002075 , Validation Accuracy =  0.2706666588783264


100%|██████████| 67/67 [00:02<00:00, 32.31it/s]
100%|██████████| 12/12 [00:00<00:00, 36.23it/s]


EPOCH -  71 . Train Accuracy =  0.821647047996521 , Validation Accuracy =  0.281333327293396


100%|██████████| 67/67 [00:02<00:00, 33.17it/s]
100%|██████████| 12/12 [00:00<00:00, 35.12it/s]


EPOCH -  72 . Train Accuracy =  0.8209412097930908 , Validation Accuracy =  0.2626666724681854


100%|██████████| 67/67 [00:02<00:00, 32.62it/s]
100%|██████████| 12/12 [00:00<00:00, 34.42it/s]


EPOCH -  73 . Train Accuracy =  0.8240000009536743 , Validation Accuracy =  0.2626666724681854


100%|██████████| 67/67 [00:02<00:00, 32.29it/s]
100%|██████████| 12/12 [00:00<00:00, 34.33it/s]


EPOCH -  74 . Train Accuracy =  0.8261176943778992 , Validation Accuracy =  0.2759999930858612


100%|██████████| 67/67 [00:02<00:00, 24.05it/s]
100%|██████████| 12/12 [00:00<00:00, 35.38it/s]


EPOCH -  75 . Train Accuracy =  0.8192941546440125 , Validation Accuracy =  0.2826666533946991


100%|██████████| 67/67 [00:02<00:00, 32.55it/s]
100%|██████████| 12/12 [00:00<00:00, 36.06it/s]


EPOCH -  76 . Train Accuracy =  0.850352942943573 , Validation Accuracy =  0.29466667771339417
Model Re-Saved


100%|██████████| 67/67 [00:02<00:00, 32.55it/s]
100%|██████████| 12/12 [00:00<00:00, 35.78it/s]


EPOCH -  77 . Train Accuracy =  0.8607059121131897 , Validation Accuracy =  0.273333340883255


100%|██████████| 67/67 [00:02<00:00, 33.13it/s]
100%|██████████| 12/12 [00:00<00:00, 35.57it/s]


EPOCH -  78 . Train Accuracy =  0.8432941436767578 , Validation Accuracy =  0.273333340883255


100%|██████████| 67/67 [00:02<00:00, 30.19it/s]
100%|██████████| 12/12 [00:00<00:00, 22.31it/s]


EPOCH -  79 . Train Accuracy =  0.8461176753044128 , Validation Accuracy =  0.29600000381469727
Model Re-Saved


100%|██████████| 67/67 [00:02<00:00, 27.77it/s]
100%|██████████| 12/12 [00:00<00:00, 36.19it/s]


EPOCH -  80 . Train Accuracy =  0.8687059283256531 , Validation Accuracy =  0.2800000011920929


100%|██████████| 67/67 [00:02<00:00, 32.38it/s]
100%|██████████| 12/12 [00:00<00:00, 34.36it/s]


EPOCH -  81 . Train Accuracy =  0.8891764879226685 , Validation Accuracy =  0.2759999930858612


100%|██████████| 67/67 [00:02<00:00, 32.91it/s]
100%|██████████| 12/12 [00:00<00:00, 36.08it/s]


EPOCH -  82 . Train Accuracy =  0.8743529319763184 , Validation Accuracy =  0.2853333353996277


100%|██████████| 67/67 [00:02<00:00, 32.75it/s]
100%|██████████| 12/12 [00:00<00:00, 35.33it/s]


EPOCH -  83 . Train Accuracy =  0.8771764636039734 , Validation Accuracy =  0.2773333191871643


100%|██████████| 67/67 [00:02<00:00, 25.64it/s]
100%|██████████| 12/12 [00:00<00:00, 24.09it/s]


EPOCH -  84 . Train Accuracy =  0.8392941355705261 , Validation Accuracy =  0.2919999957084656


100%|██████████| 67/67 [00:02<00:00, 30.75it/s]
100%|██████████| 12/12 [00:00<00:00, 32.81it/s]


EPOCH -  85 . Train Accuracy =  0.8823529481887817 , Validation Accuracy =  0.30666667222976685
Model Re-Saved


100%|██████████| 67/67 [00:02<00:00, 32.40it/s]
100%|██████████| 12/12 [00:00<00:00, 35.16it/s]


EPOCH -  86 . Train Accuracy =  0.8837647438049316 , Validation Accuracy =  0.2826666533946991


100%|██████████| 67/67 [00:02<00:00, 32.38it/s]
100%|██████████| 12/12 [00:00<00:00, 36.27it/s]


EPOCH -  87 . Train Accuracy =  0.8988235592842102 , Validation Accuracy =  0.31599998474121094
Model Re-Saved


100%|██████████| 67/67 [00:02<00:00, 31.91it/s]
100%|██████████| 12/12 [00:00<00:00, 31.34it/s]


EPOCH -  88 . Train Accuracy =  0.882588267326355 , Validation Accuracy =  0.30266666412353516


100%|██████████| 67/67 [00:02<00:00, 24.11it/s]
100%|██████████| 12/12 [00:00<00:00, 34.22it/s]


EPOCH -  89 . Train Accuracy =  0.9018823504447937 , Validation Accuracy =  0.30133333802223206


100%|██████████| 67/67 [00:02<00:00, 31.40it/s]
100%|██████████| 12/12 [00:00<00:00, 35.41it/s]


EPOCH -  90 . Train Accuracy =  0.9025882482528687 , Validation Accuracy =  0.30399999022483826


  0%|          | 0/67 [00:00<?, ?it/s]

In [8]:
# -------------------------------------------- CIFAR-10 Dataset -----------------------------------------------------------
Option = 5
Train_DataLoader, Val_DataLoader, Test_DataLoader = CreateDataLoaders(Option, 64)
ModelName = "Model_MLP_IOCN_CIFAR10.pt"

model = MLP_IOCN(3072, 10).to(device)
loss_criteria = nn.CrossEntropyLoss()
AdamOpt = torch.optim.Adam(model.parameters(), lr=0.0001)

ValAccuracy = 0

saved = False
if saved == True:
    saved_model = torch.load(ModelName, map_location=torch.device('cpu')).to(device)
else:
    model.train()
    Train_Accuracy = 0

    for e in range(EPOCHS):
        Train_Accuracy = TrainModel(model, loss_criteria, AdamOpt, device, Train_DataLoader)
        val_acc = EvaluateModel(model, Val_DataLoader, device)

        print("EPOCH - ", e+1, ". Train Accuracy = ", Train_Accuracy.cpu().item(), ", Validation Accuracy = ", val_acc.cpu().item())
        
        if val_acc.cpu().item() > ValAccuracy:
            print("Model Re-Saved")
            ValAccuracy = val_acc.cpu().item()
            torch.save(model, ModelName)

    saved_model = torch.load(ModelName, map_location=torch.device('cpu')).to(device)


print()

Train_Accuracy = EvaluateModel(saved_model, Train_DataLoader, device)
print("Train Accuracy = ", Train_Accuracy.cpu().item())
Test_Accuracy = EvaluateModel(saved_model, Test_DataLoader, device)
print("Test Accuracy = ", Test_Accuracy.cpu().item())

print("Generalization Gap = ", (Train_Accuracy.cpu().item() - Test_Accuracy.cpu().item()))

Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to ./cifar-10-python.tar.gz


100%|██████████| 170498071/170498071 [00:04<00:00, 40515881.28it/s]


Extracting ./cifar-10-python.tar.gz to ./
Files already downloaded and verified


100%|██████████| 665/665 [00:11<00:00, 57.63it/s]
100%|██████████| 118/118 [00:01<00:00, 88.33it/s]


EPOCH -  1 . Train Accuracy =  0.3648470640182495 , Validation Accuracy =  0.39399999380111694
Model Re-Saved


100%|██████████| 665/665 [00:09<00:00, 72.47it/s]
100%|██████████| 118/118 [00:01<00:00, 86.81it/s]


EPOCH -  2 . Train Accuracy =  0.42494118213653564 , Validation Accuracy =  0.44493332505226135
Model Re-Saved


100%|██████████| 665/665 [00:09<00:00, 72.01it/s]
100%|██████████| 118/118 [00:01<00:00, 85.72it/s]


EPOCH -  3 . Train Accuracy =  0.462799996137619 , Validation Accuracy =  0.4525333344936371
Model Re-Saved


100%|██████████| 665/665 [00:09<00:00, 72.47it/s]
100%|██████████| 118/118 [00:01<00:00, 88.14it/s]


EPOCH -  4 . Train Accuracy =  0.4799294173717499 , Validation Accuracy =  0.47253334522247314
Model Re-Saved


100%|██████████| 665/665 [00:09<00:00, 72.08it/s]
100%|██████████| 118/118 [00:01<00:00, 86.76it/s]


EPOCH -  5 . Train Accuracy =  0.49658823013305664 , Validation Accuracy =  0.47040000557899475


100%|██████████| 665/665 [00:08<00:00, 74.53it/s]
100%|██████████| 118/118 [00:01<00:00, 74.76it/s]


EPOCH -  6 . Train Accuracy =  0.5126588344573975 , Validation Accuracy =  0.47386667132377625
Model Re-Saved


100%|██████████| 665/665 [00:08<00:00, 76.94it/s]
100%|██████████| 118/118 [00:01<00:00, 66.14it/s]


EPOCH -  7 . Train Accuracy =  0.522635281085968 , Validation Accuracy =  0.4742666780948639
Model Re-Saved


100%|██████████| 665/665 [00:08<00:00, 74.54it/s]
100%|██████████| 118/118 [00:01<00:00, 87.57it/s]


EPOCH -  8 . Train Accuracy =  0.5322823524475098 , Validation Accuracy =  0.48653334379196167
Model Re-Saved


100%|██████████| 665/665 [00:09<00:00, 71.07it/s]
100%|██████████| 118/118 [00:01<00:00, 87.06it/s]


EPOCH -  9 . Train Accuracy =  0.5412235260009766 , Validation Accuracy =  0.481466680765152


100%|██████████| 665/665 [00:09<00:00, 71.63it/s]
100%|██████████| 118/118 [00:01<00:00, 88.10it/s]


EPOCH -  10 . Train Accuracy =  0.5566588044166565 , Validation Accuracy =  0.49560001492500305
Model Re-Saved


100%|██████████| 665/665 [00:09<00:00, 70.56it/s]
100%|██████████| 118/118 [00:01<00:00, 87.16it/s]


EPOCH -  11 . Train Accuracy =  0.5615764856338501 , Validation Accuracy =  0.4896000027656555


100%|██████████| 665/665 [00:09<00:00, 71.72it/s]
100%|██████████| 118/118 [00:01<00:00, 88.37it/s]


EPOCH -  12 . Train Accuracy =  0.5721647143363953 , Validation Accuracy =  0.5154666900634766
Model Re-Saved


100%|██████████| 665/665 [00:09<00:00, 71.73it/s]
100%|██████████| 118/118 [00:01<00:00, 83.90it/s]


EPOCH -  13 . Train Accuracy =  0.5784705877304077 , Validation Accuracy =  0.5052000284194946


100%|██████████| 665/665 [00:08<00:00, 77.39it/s]
100%|██████████| 118/118 [00:01<00:00, 63.11it/s]


EPOCH -  14 . Train Accuracy =  0.5882823467254639 , Validation Accuracy =  0.5099999904632568


100%|██████████| 665/665 [00:08<00:00, 77.01it/s]
100%|██████████| 118/118 [00:01<00:00, 85.99it/s]


EPOCH -  15 . Train Accuracy =  0.5956941246986389 , Validation Accuracy =  0.4962666630744934


100%|██████████| 665/665 [00:09<00:00, 72.32it/s]
100%|██████████| 118/118 [00:01<00:00, 88.53it/s]


EPOCH -  16 . Train Accuracy =  0.598800003528595 , Validation Accuracy =  0.4981333315372467


100%|██████████| 665/665 [00:09<00:00, 71.51it/s]
100%|██████████| 118/118 [00:01<00:00, 87.97it/s]


EPOCH -  17 . Train Accuracy =  0.611294150352478 , Validation Accuracy =  0.5113333463668823


100%|██████████| 665/665 [00:09<00:00, 71.32it/s]
100%|██████████| 118/118 [00:01<00:00, 88.56it/s]


EPOCH -  18 . Train Accuracy =  0.6138352751731873 , Validation Accuracy =  0.5078666806221008


100%|██████████| 665/665 [00:09<00:00, 69.77it/s]
100%|██████████| 118/118 [00:01<00:00, 85.76it/s]


EPOCH -  19 . Train Accuracy =  0.624047040939331 , Validation Accuracy =  0.5122666954994202


100%|██████████| 665/665 [00:09<00:00, 73.13it/s]
100%|██████████| 118/118 [00:01<00:00, 75.98it/s]


EPOCH -  20 . Train Accuracy =  0.6311529278755188 , Validation Accuracy =  0.5099999904632568


100%|██████████| 665/665 [00:08<00:00, 78.59it/s]
100%|██████████| 118/118 [00:01<00:00, 68.14it/s]


EPOCH -  21 . Train Accuracy =  0.6309882402420044 , Validation Accuracy =  0.5152000188827515


100%|██████████| 665/665 [00:08<00:00, 74.01it/s]
100%|██████████| 118/118 [00:01<00:00, 88.77it/s]


EPOCH -  22 . Train Accuracy =  0.6402117609977722 , Validation Accuracy =  0.5112000107765198


100%|██████████| 665/665 [00:09<00:00, 70.84it/s]
100%|██████████| 118/118 [00:01<00:00, 88.89it/s]


EPOCH -  23 . Train Accuracy =  0.6486823558807373 , Validation Accuracy =  0.5146666765213013


100%|██████████| 665/665 [00:09<00:00, 71.32it/s]
100%|██████████| 118/118 [00:01<00:00, 84.78it/s]


EPOCH -  24 . Train Accuracy =  0.6584705710411072 , Validation Accuracy =  0.5008000135421753


100%|██████████| 665/665 [00:09<00:00, 70.67it/s]
100%|██████████| 118/118 [00:01<00:00, 87.90it/s]


EPOCH -  25 . Train Accuracy =  0.658258855342865 , Validation Accuracy =  0.5180000066757202
Model Re-Saved


100%|██████████| 665/665 [00:09<00:00, 71.76it/s]
100%|██████████| 118/118 [00:01<00:00, 85.44it/s]


EPOCH -  26 . Train Accuracy =  0.6686588525772095 , Validation Accuracy =  0.518666684627533
Model Re-Saved


100%|██████████| 665/665 [00:09<00:00, 73.25it/s]
100%|██████████| 118/118 [00:01<00:00, 76.12it/s]


EPOCH -  27 . Train Accuracy =  0.6708706021308899 , Validation Accuracy =  0.5198666453361511
Model Re-Saved


100%|██████████| 665/665 [00:08<00:00, 77.44it/s]
100%|██████████| 118/118 [00:01<00:00, 66.45it/s]


EPOCH -  28 . Train Accuracy =  0.6778117418289185 , Validation Accuracy =  0.5311999917030334
Model Re-Saved


100%|██████████| 665/665 [00:08<00:00, 75.66it/s]
100%|██████████| 118/118 [00:01<00:00, 89.39it/s]


EPOCH -  29 . Train Accuracy =  0.6877176761627197 , Validation Accuracy =  0.5253333449363708


100%|██████████| 665/665 [00:09<00:00, 71.47it/s]
100%|██████████| 118/118 [00:01<00:00, 88.45it/s]


EPOCH -  30 . Train Accuracy =  0.6893882155418396 , Validation Accuracy =  0.5152000188827515


100%|██████████| 665/665 [00:09<00:00, 71.14it/s]
100%|██████████| 118/118 [00:01<00:00, 93.19it/s]


EPOCH -  31 . Train Accuracy =  0.6982117891311646 , Validation Accuracy =  0.515333354473114


100%|██████████| 665/665 [00:09<00:00, 72.20it/s]
100%|██████████| 118/118 [00:01<00:00, 88.80it/s]


EPOCH -  32 . Train Accuracy =  0.7021411657333374 , Validation Accuracy =  0.5193333625793457


100%|██████████| 665/665 [00:09<00:00, 72.22it/s]
100%|██████████| 118/118 [00:01<00:00, 86.36it/s]


EPOCH -  33 . Train Accuracy =  0.7053882479667664 , Validation Accuracy =  0.5216000080108643


100%|██████████| 665/665 [00:08<00:00, 75.10it/s]
100%|██████████| 118/118 [00:01<00:00, 69.18it/s]


EPOCH -  34 . Train Accuracy =  0.7128000259399414 , Validation Accuracy =  0.5326666831970215
Model Re-Saved


100%|██████████| 665/665 [00:08<00:00, 78.79it/s]
100%|██████████| 118/118 [00:01<00:00, 76.96it/s]


EPOCH -  35 . Train Accuracy =  0.7146588563919067 , Validation Accuracy =  0.5233333110809326


100%|██████████| 665/665 [00:09<00:00, 73.61it/s]
100%|██████████| 118/118 [00:01<00:00, 92.62it/s]


EPOCH -  36 . Train Accuracy =  0.7188941240310669 , Validation Accuracy =  0.5281333327293396


100%|██████████| 665/665 [00:09<00:00, 71.58it/s]
100%|██████████| 118/118 [00:01<00:00, 84.96it/s]


EPOCH -  37 . Train Accuracy =  0.7235764861106873 , Validation Accuracy =  0.5171999931335449


100%|██████████| 665/665 [00:09<00:00, 72.79it/s]
100%|██████████| 118/118 [00:01<00:00, 91.50it/s]


EPOCH -  38 . Train Accuracy =  0.7299529314041138 , Validation Accuracy =  0.5089333653450012


100%|██████████| 665/665 [00:09<00:00, 73.72it/s]
100%|██████████| 118/118 [00:01<00:00, 89.93it/s]


EPOCH -  39 . Train Accuracy =  0.7355294227600098 , Validation Accuracy =  0.5221333503723145


100%|██████████| 665/665 [00:09<00:00, 73.01it/s]
100%|██████████| 118/118 [00:01<00:00, 89.10it/s]


EPOCH -  40 . Train Accuracy =  0.7421176433563232 , Validation Accuracy =  0.5206666588783264


100%|██████████| 665/665 [00:08<00:00, 78.30it/s]
100%|██████████| 118/118 [00:01<00:00, 63.04it/s]


EPOCH -  41 . Train Accuracy =  0.7466588020324707 , Validation Accuracy =  0.5225333571434021


100%|██████████| 665/665 [00:08<00:00, 76.84it/s]
100%|██████████| 118/118 [00:01<00:00, 87.69it/s]


EPOCH -  42 . Train Accuracy =  0.7500705718994141 , Validation Accuracy =  0.5166666507720947


100%|██████████| 665/665 [00:09<00:00, 71.63it/s]
100%|██████████| 118/118 [00:01<00:00, 90.51it/s]


EPOCH -  43 . Train Accuracy =  0.7539294362068176 , Validation Accuracy =  0.5254666805267334


100%|██████████| 665/665 [00:09<00:00, 71.10it/s]
100%|██████████| 118/118 [00:01<00:00, 88.06it/s]


EPOCH -  44 . Train Accuracy =  0.7570823431015015 , Validation Accuracy =  0.5198666453361511


100%|██████████| 665/665 [00:09<00:00, 72.99it/s]
100%|██████████| 118/118 [00:01<00:00, 90.24it/s]


EPOCH -  45 . Train Accuracy =  0.7607529759407043 , Validation Accuracy =  0.522266685962677


100%|██████████| 665/665 [00:09<00:00, 73.43it/s]
100%|██████████| 118/118 [00:01<00:00, 88.89it/s]


EPOCH -  46 . Train Accuracy =  0.7680705785751343 , Validation Accuracy =  0.5225333571434021


100%|██████████| 665/665 [00:08<00:00, 75.59it/s]
100%|██████████| 118/118 [00:01<00:00, 69.93it/s]


EPOCH -  47 . Train Accuracy =  0.7744235396385193 , Validation Accuracy =  0.5157333612442017


100%|██████████| 665/665 [00:08<00:00, 79.04it/s]
100%|██████████| 118/118 [00:01<00:00, 80.42it/s]


EPOCH -  48 . Train Accuracy =  0.7747764587402344 , Validation Accuracy =  0.5108000040054321


100%|██████████| 665/665 [00:09<00:00, 73.42it/s]
100%|██████████| 118/118 [00:01<00:00, 86.64it/s]


EPOCH -  49 . Train Accuracy =  0.7745882272720337 , Validation Accuracy =  0.5130666494369507


100%|██████████| 665/665 [00:09<00:00, 71.58it/s]
100%|██████████| 118/118 [00:01<00:00, 90.77it/s]


EPOCH -  50 . Train Accuracy =  0.7800235152244568 , Validation Accuracy =  0.5097333192825317


100%|██████████| 665/665 [00:09<00:00, 72.24it/s]
100%|██████████| 118/118 [00:01<00:00, 91.55it/s]


EPOCH -  51 . Train Accuracy =  0.7868705987930298 , Validation Accuracy =  0.5188000202178955


100%|██████████| 665/665 [00:09<00:00, 72.29it/s]
100%|██████████| 118/118 [00:01<00:00, 89.57it/s]


EPOCH -  52 . Train Accuracy =  0.7923294305801392 , Validation Accuracy =  0.5089333653450012


100%|██████████| 665/665 [00:09<00:00, 72.15it/s]
100%|██████████| 118/118 [00:01<00:00, 89.25it/s]


EPOCH -  53 . Train Accuracy =  0.7865176796913147 , Validation Accuracy =  0.5042666792869568


100%|██████████| 665/665 [00:08<00:00, 78.14it/s]
100%|██████████| 118/118 [00:01<00:00, 61.47it/s]


EPOCH -  54 . Train Accuracy =  0.7923764586448669 , Validation Accuracy =  0.5030666589736938


100%|██████████| 665/665 [00:08<00:00, 77.69it/s]
100%|██████████| 118/118 [00:01<00:00, 90.94it/s]


EPOCH -  55 . Train Accuracy =  0.8015058636665344 , Validation Accuracy =  0.5144000053405762


100%|██████████| 665/665 [00:09<00:00, 71.48it/s]
100%|██████████| 118/118 [00:01<00:00, 90.25it/s]


EPOCH -  56 . Train Accuracy =  0.7957412004470825 , Validation Accuracy =  0.5185333490371704


100%|██████████| 665/665 [00:09<00:00, 72.67it/s]
100%|██████████| 118/118 [00:01<00:00, 91.34it/s]


EPOCH -  57 . Train Accuracy =  0.8048000335693359 , Validation Accuracy =  0.5178666710853577


100%|██████████| 665/665 [00:09<00:00, 72.16it/s]
100%|██████████| 118/118 [00:01<00:00, 88.14it/s]


EPOCH -  58 . Train Accuracy =  0.8111059069633484 , Validation Accuracy =  0.5134666562080383


100%|██████████| 665/665 [00:09<00:00, 71.00it/s]
100%|██████████| 118/118 [00:01<00:00, 88.35it/s]


EPOCH -  59 . Train Accuracy =  0.8132941126823425 , Validation Accuracy =  0.5173333287239075


100%|██████████| 665/665 [00:08<00:00, 74.96it/s]
100%|██████████| 118/118 [00:01<00:00, 72.43it/s]


EPOCH -  60 . Train Accuracy =  0.8179529309272766 , Validation Accuracy =  0.5099999904632568


100%|██████████| 665/665 [00:08<00:00, 79.09it/s]
100%|██████████| 118/118 [00:01<00:00, 75.58it/s]


EPOCH -  61 . Train Accuracy =  0.8172705769538879 , Validation Accuracy =  0.5178666710853577


100%|██████████| 665/665 [00:09<00:00, 73.71it/s]
100%|██████████| 118/118 [00:01<00:00, 92.32it/s]


EPOCH -  62 . Train Accuracy =  0.8250823616981506 , Validation Accuracy =  0.5286666750907898


100%|██████████| 665/665 [00:11<00:00, 60.37it/s]
100%|██████████| 118/118 [00:01<00:00, 91.69it/s]


EPOCH -  63 . Train Accuracy =  0.8240705728530884 , Validation Accuracy =  0.5086666941642761


100%|██████████| 665/665 [00:09<00:00, 73.61it/s]
100%|██████████| 118/118 [00:01<00:00, 88.22it/s]


EPOCH -  64 . Train Accuracy =  0.8250588178634644 , Validation Accuracy =  0.5144000053405762


100%|██████████| 665/665 [00:09<00:00, 71.84it/s]
100%|██████████| 118/118 [00:01<00:00, 89.68it/s]


EPOCH -  65 . Train Accuracy =  0.8248235583305359 , Validation Accuracy =  0.5040000081062317


100%|██████████| 665/665 [00:09<00:00, 72.23it/s]
100%|██████████| 118/118 [00:01<00:00, 88.90it/s]


EPOCH -  66 . Train Accuracy =  0.82948237657547 , Validation Accuracy =  0.512666642665863


100%|██████████| 665/665 [00:09<00:00, 70.60it/s]
100%|██████████| 118/118 [00:01<00:00, 88.71it/s]


EPOCH -  67 . Train Accuracy =  0.8382117748260498 , Validation Accuracy =  0.5027999877929688


100%|██████████| 665/665 [00:08<00:00, 78.72it/s]
100%|██████████| 118/118 [00:01<00:00, 66.23it/s]


EPOCH -  68 . Train Accuracy =  0.8377882242202759 , Validation Accuracy =  0.5102666616439819


100%|██████████| 665/665 [00:08<00:00, 76.71it/s]
100%|██████████| 118/118 [00:01<00:00, 92.38it/s]


EPOCH -  69 . Train Accuracy =  0.8407294154167175 , Validation Accuracy =  0.5057333111763


100%|██████████| 665/665 [00:09<00:00, 72.18it/s]
100%|██████████| 118/118 [00:01<00:00, 89.89it/s]


EPOCH -  70 . Train Accuracy =  0.8464000225067139 , Validation Accuracy =  0.5150666832923889


100%|██████████| 665/665 [00:09<00:00, 71.97it/s]
100%|██████████| 118/118 [00:01<00:00, 89.97it/s]


EPOCH -  71 . Train Accuracy =  0.8421646952629089 , Validation Accuracy =  0.5089333653450012


100%|██████████| 665/665 [00:09<00:00, 71.03it/s]
100%|██████████| 118/118 [00:01<00:00, 92.26it/s]


EPOCH -  72 . Train Accuracy =  0.8440940976142883 , Validation Accuracy =  0.5090667009353638


100%|██████████| 665/665 [00:09<00:00, 72.24it/s]
100%|██████████| 118/118 [00:01<00:00, 89.04it/s]


EPOCH -  73 . Train Accuracy =  0.8487529754638672 , Validation Accuracy =  0.5260000228881836


100%|██████████| 665/665 [00:08<00:00, 75.11it/s]
100%|██████████| 118/118 [00:01<00:00, 68.22it/s]


EPOCH -  74 . Train Accuracy =  0.8565176725387573 , Validation Accuracy =  0.5073333382606506


100%|██████████| 665/665 [00:08<00:00, 79.74it/s]
100%|██████████| 118/118 [00:01<00:00, 89.51it/s]


EPOCH -  75 . Train Accuracy =  0.8569411635398865 , Validation Accuracy =  0.5077333450317383


100%|██████████| 665/665 [00:09<00:00, 72.88it/s]
100%|██████████| 118/118 [00:01<00:00, 90.25it/s]


EPOCH -  76 . Train Accuracy =  0.8507999777793884 , Validation Accuracy =  0.5078666806221008


100%|██████████| 665/665 [00:09<00:00, 73.06it/s]
100%|██████████| 118/118 [00:01<00:00, 87.59it/s]


EPOCH -  77 . Train Accuracy =  0.8645647168159485 , Validation Accuracy =  0.5080000162124634


100%|██████████| 665/665 [00:09<00:00, 72.02it/s]
100%|██████████| 118/118 [00:01<00:00, 90.60it/s]


EPOCH -  78 . Train Accuracy =  0.8609882593154907 , Validation Accuracy =  0.5026666522026062


100%|██████████| 665/665 [00:09<00:00, 72.23it/s]
100%|██████████| 118/118 [00:01<00:00, 91.15it/s]


EPOCH -  79 . Train Accuracy =  0.8553647398948669 , Validation Accuracy =  0.5122666954994202


100%|██████████| 665/665 [00:08<00:00, 75.62it/s]
100%|██████████| 118/118 [00:01<00:00, 65.20it/s]


EPOCH -  80 . Train Accuracy =  0.8586353063583374 , Validation Accuracy =  0.5203999876976013


100%|██████████| 665/665 [00:08<00:00, 78.72it/s]
100%|██████████| 118/118 [00:01<00:00, 72.84it/s]


EPOCH -  81 . Train Accuracy =  0.8685647249221802 , Validation Accuracy =  0.5070666670799255


100%|██████████| 665/665 [00:09<00:00, 73.25it/s]
100%|██████████| 118/118 [00:01<00:00, 86.99it/s]


EPOCH -  82 . Train Accuracy =  0.8678117990493774 , Validation Accuracy =  0.503333330154419


100%|██████████| 665/665 [00:09<00:00, 72.38it/s]
100%|██████████| 118/118 [00:01<00:00, 89.55it/s]


EPOCH -  83 . Train Accuracy =  0.8693647384643555 , Validation Accuracy =  0.5122666954994202


100%|██████████| 665/665 [00:09<00:00, 72.30it/s]
100%|██████████| 118/118 [00:01<00:00, 88.50it/s]


EPOCH -  84 . Train Accuracy =  0.8723999857902527 , Validation Accuracy =  0.5078666806221008


100%|██████████| 665/665 [00:09<00:00, 72.56it/s]
100%|██████████| 118/118 [00:01<00:00, 89.73it/s]


EPOCH -  85 . Train Accuracy =  0.8736470937728882 , Validation Accuracy =  0.5101333260536194


100%|██████████| 665/665 [00:09<00:00, 71.82it/s]
100%|██████████| 118/118 [00:01<00:00, 88.12it/s]


EPOCH -  86 . Train Accuracy =  0.8802117705345154 , Validation Accuracy =  0.5098666548728943


100%|██████████| 665/665 [00:08<00:00, 75.34it/s]
100%|██████████| 118/118 [00:01<00:00, 59.19it/s]


EPOCH -  87 . Train Accuracy =  0.8799294233322144 , Validation Accuracy =  0.5018666982650757


100%|██████████| 665/665 [00:08<00:00, 77.69it/s]
100%|██████████| 118/118 [00:01<00:00, 83.04it/s]


EPOCH -  88 . Train Accuracy =  0.8753882646560669 , Validation Accuracy =  0.5073333382606506


100%|██████████| 665/665 [00:09<00:00, 71.60it/s]
100%|██████████| 118/118 [00:01<00:00, 89.40it/s]


EPOCH -  89 . Train Accuracy =  0.8765647411346436 , Validation Accuracy =  0.5098666548728943


100%|██████████| 665/665 [00:09<00:00, 71.81it/s]
100%|██████████| 118/118 [00:01<00:00, 90.11it/s]


EPOCH -  90 . Train Accuracy =  0.8813411593437195 , Validation Accuracy =  0.49586668610572815


100%|██████████| 665/665 [00:09<00:00, 71.64it/s]
100%|██████████| 118/118 [00:01<00:00, 89.06it/s]


EPOCH -  91 . Train Accuracy =  0.8762588500976562 , Validation Accuracy =  0.4893333315849304


100%|██████████| 665/665 [00:09<00:00, 71.89it/s]
100%|██████████| 118/118 [00:01<00:00, 89.92it/s]


EPOCH -  92 . Train Accuracy =  0.8832706212997437 , Validation Accuracy =  0.5156000256538391


100%|██████████| 665/665 [00:09<00:00, 71.28it/s]
100%|██████████| 118/118 [00:01<00:00, 86.08it/s]


EPOCH -  93 . Train Accuracy =  0.8892941474914551 , Validation Accuracy =  0.5077333450317383


100%|██████████| 665/665 [00:08<00:00, 78.56it/s]
100%|██████████| 118/118 [00:01<00:00, 63.87it/s]


EPOCH -  94 . Train Accuracy =  0.8923764824867249 , Validation Accuracy =  0.508400022983551


100%|██████████| 665/665 [00:08<00:00, 76.52it/s]
100%|██████████| 118/118 [00:01<00:00, 90.95it/s]


EPOCH -  95 . Train Accuracy =  0.8912235498428345 , Validation Accuracy =  0.5097333192825317


100%|██████████| 665/665 [00:09<00:00, 70.71it/s]
100%|██████████| 118/118 [00:01<00:00, 89.76it/s]


EPOCH -  96 . Train Accuracy =  0.8893647193908691 , Validation Accuracy =  0.5061333179473877


100%|██████████| 665/665 [00:09<00:00, 71.66it/s]
100%|██████████| 118/118 [00:01<00:00, 89.50it/s]


EPOCH -  97 . Train Accuracy =  0.8881646990776062 , Validation Accuracy =  0.5009333491325378


100%|██████████| 665/665 [00:09<00:00, 71.37it/s]
100%|██████████| 118/118 [00:01<00:00, 89.85it/s]


EPOCH -  98 . Train Accuracy =  0.8959059119224548 , Validation Accuracy =  0.5156000256538391


100%|██████████| 665/665 [00:09<00:00, 71.17it/s]
100%|██████████| 118/118 [00:01<00:00, 91.58it/s]


EPOCH -  99 . Train Accuracy =  0.8978588581085205 , Validation Accuracy =  0.506933331489563


100%|██████████| 665/665 [00:09<00:00, 73.52it/s]
100%|██████████| 118/118 [00:01<00:00, 74.91it/s]


EPOCH -  100 . Train Accuracy =  0.8922588229179382 , Validation Accuracy =  0.5129333138465881



100%|██████████| 665/665 [00:07<00:00, 88.46it/s]


Train Accuracy =  0.7432470917701721


100%|██████████| 157/157 [00:01<00:00, 87.90it/s]

Test Accuracy =  0.5155999660491943
Generalization Gap =  0.22764712572097778



