In [None]:
import torch
import torch.nn as nn
import torchvision
from torchvision import datasets
from torchvision import transforms

import numpy as np
from tqdm import tqdm
import math

In [None]:
# Preprocessing pipeline handed to every dataset constructor in this notebook.
# Only the tensor conversion is active; the optional
# transforms.Normalize((0.5,), (0.5,)) step is currently disabled.
transform = transforms.Compose([transforms.ToTensor()])

In [None]:
def CreateDataLoaders(Option, p_BatchSize):
    """Build train / validation / test DataLoaders for one torchvision dataset.

    Args:
        Option: Dataset selector -- 1: MNIST, 2: FashionMNIST, 5: CIFAR10,
            6: CIFAR100.
        p_BatchSize: Batch size used by all three loaders.

    Returns:
        Tuple ``(Train_DataLoader, Val_DataLoader, Test_DataLoader)``. The
        dataset's ``train=True`` split is divided 85% / 15% into train and
        validation subsets with a fixed generator seed; the ``train=False``
        split backs the test loader.

    Raises:
        ValueError: If ``Option`` is not one of the supported selectors.
    """
    dataset_names = {1: "MNIST", 2: "FashionMNIST", 5: "CIFAR10", 6: "CIFAR100"}
    if Option not in dataset_names:
        # Fail early with a clear message instead of an UnboundLocalError later.
        raise ValueError(
            f"Unsupported dataset Option {Option!r}; expected one of {sorted(dataset_names)}"
        )
    dataset_cls = getattr(datasets, dataset_names[Option])

    train_dataset = dataset_cls("./", train=True, transform = transform, download=True)
    test_dataset = dataset_cls("./", train=False, transform = transform, download=True)

    # Derive the validation size from the remainder so the two lengths always
    # sum to len(train_dataset), which random_split requires.
    train_len = int(len(train_dataset) * 0.85)
    val_len = len(train_dataset) - train_len
    train_subset, val_subset = torch.utils.data.random_split(
        train_dataset,
        [train_len, val_len],
        generator=torch.Generator().manual_seed(1),
    )

    Train_DataLoader = torch.utils.data.DataLoader(train_subset, batch_size = p_BatchSize, shuffle = True)
    # Evaluation metrics do not depend on sample order, so the held-out loaders
    # are iterated deterministically.
    Val_DataLoader = torch.utils.data.DataLoader(val_subset, batch_size = p_BatchSize, shuffle = False)
    Test_DataLoader = torch.utils.data.DataLoader(test_dataset, batch_size = p_BatchSize, shuffle = False)

    return Train_DataLoader, Val_DataLoader, Test_DataLoader

In [None]:
class AllConv_IOCN(nn.Module):
    """All-convolutional classifier: nine ReLU conv layers then global average pooling.

    Conv3 and Conv6 use stride 2 and act as the down-sampling stages; Conv8 and
    Conv9 are 1x1 convolutions, with Conv9 producing one channel per class.

    Args:
        OutputDim: Number of output channels of the last convolution (one per
            class).
        InputChannels: Number of channels of the input images. Defaults to 3,
            the value that was previously hard-coded; pass 1 for grayscale
            inputs.

    The parameter names (``Conv1.weight`` ... ``Conv9.bias``) are part of the
    contract: the training loop in this notebook selects layers by name.
    """

    def __init__(self, OutputDim, InputChannels = 3):
        super(AllConv_IOCN, self).__init__()
        self.OutputDim = OutputDim
        self.InputChannels = InputChannels

        self.Conv1 = nn.Conv2d(InputChannels, 96, kernel_size = 3)
        self.Conv2 = nn.Conv2d(96, 96, kernel_size = 3)
        self.Conv3 = nn.Conv2d(96, 96, kernel_size = 3, stride = 2)
        self.Conv4 = nn.Conv2d(96, 192, kernel_size = 3)
        self.Conv5 = nn.Conv2d(192, 192, kernel_size = 3)
        self.Conv6 = nn.Conv2d(192, 192, kernel_size = 3, stride = 2)
        self.Conv7 = nn.Conv2d(192, 192, kernel_size = 3)
        self.Conv8 = nn.Conv2d(192, 192, kernel_size = 1)
        self.Conv9 = nn.Conv2d(192, self.OutputDim, kernel_size = 1)

        # Global average pooling over whatever spatial extent is left. For a
        # 32x32 input the unpadded convolutions leave a 2x2 map, so this equals
        # the former AvgPool2d(kernel_size = 2); for other input sizes it still
        # yields exactly one value per class.
        # NOTE(review): the original comment said the paper pools over 6 x 6.
        self.AvgPool = nn.AdaptiveAvgPool2d(1)

        self.ActFunc = nn.functional.relu

    def forward(self, x):
        """Map a batch of images to a tensor of shape (batch, OutputDim, 1, 1)."""
        conv_stack = (
            self.Conv1, self.Conv2, self.Conv3,
            self.Conv4, self.Conv5, self.Conv6,
            self.Conv7, self.Conv8, self.Conv9,
        )
        # Every convolution, including the last one, is followed by the activation.
        for conv in conv_stack:
            x = self.ActFunc(conv(x))

        return self.AvgPool(x)

In [None]:
def _ConstrainWeights(p_model, Gamma):
    """Replace every negative parameter value w with exp(w - Gamma), in place.

    Parameters whose name contains "Conv1" (the first layer) are left untouched.
    Non-negative values are kept as they are, so calling this twice in a row is
    a no-op the second time.
    """
    with torch.no_grad():
        for name, param in p_model.named_parameters():
            if "Conv1" in name:
                continue
            param.copy_(torch.where(param < 0, torch.exp(param - Gamma), param))


def TrainModel(p_model, loss_criteria, Optimizer, device,  p_TrainDL, Gamma = 5):
    """Train ``p_model`` for one pass over ``p_TrainDL``.

    After every optimizer step the negative parameters of all layers except
    "Conv1" are mapped to ``exp(w - Gamma)``. Presumably this is the weight
    non-negativity constraint of an input-output convex network, as the model
    class name suggests -- TODO confirm against the reference method.
    NOTE(review): biases of the constrained layers are mapped as well, exactly
    as before; confirm that is intended.

    The constraint used to be applied between the forward pass and
    ``loss.backward()``, which meant gradients were taken with respect to
    weights that differed from the ones that produced the loss, and the model
    left behind after ``Optimizer.step()`` was unconstrained. It is now applied
    once before the first batch and again after every step, so every forward
    pass and the returned model see constrained weights.

    Args:
        p_model: Network to train; its output is flattened to (batch, -1).
        loss_criteria: Callable ``loss_criteria(pred, labels)`` returning the loss.
        Optimizer: Optimizer that owns ``p_model``'s parameters.
        device: Device the batches are moved to.
        p_TrainDL: Iterable yielding ``(images, labels)`` batches.
        Gamma: Offset subtracted inside the exponential of the constraint.

    Returns:
        0-dim tensor with the fraction of correctly classified training samples
        (0 when the loader yields no samples).
    """
    p_model.train()
    TrainCorr = torch.zeros((), dtype=torch.long, device=device)
    TotNumOfSamples = 0

    # Make sure the very first forward pass already uses constrained weights.
    _ConstrainWeights(p_model, Gamma)

    for images, labels in tqdm(p_TrainDL):
        images = images.to(device)
        labels = labels.to(device)

        Optimizer.zero_grad()
        pred = p_model(images)
        pred = pred.reshape(pred.shape[0], -1)
        loss = loss_criteria(pred, labels)
        loss.backward()
        Optimizer.step()

        # Re-impose the constraint on the freshly updated weights.
        _ConstrainWeights(p_model, Gamma)

        predClass = torch.max(pred.detach(), 1)[1]
        TrainCorr += (predClass==labels).sum()
        TotNumOfSamples += len(labels)

    # max(..., 1) keeps an empty loader from raising ZeroDivisionError.
    return TrainCorr / max(TotNumOfSamples, 1)

def EvaluateModel(p_model, p_loader, device):
    """Compute the classification accuracy of ``p_model`` over ``p_loader``.

    The pass runs in eval mode with gradient tracking disabled, so no autograd
    graph is built; the model's previous train/eval mode is restored afterwards.

    Args:
        p_model: Network to evaluate; its output is flattened to (batch, -1)
            and the arg-max over dimension 1 is taken as the predicted class.
        p_loader: Iterable yielding ``(images, labels)`` batches.
        device: Device the batches are moved to.

    Returns:
        0-dim tensor with the fraction of correctly classified samples
        (0 when the loader yields no samples).
    """
    Correct = torch.zeros((), dtype=torch.long, device=device)
    TotalNumOfSamples = 0

    was_training = p_model.training
    p_model.eval()
    try:
        with torch.no_grad():
            for images, labels in tqdm(p_loader):
                images = images.to(device)
                labels = labels.to(device)

                pred = p_model(images)
                pred = pred.reshape(pred.shape[0], -1)
                predClass = torch.max(pred, 1)[1]
                Correct += (predClass==labels).sum()
                TotalNumOfSamples += len(labels)
    finally:
        # Hand the model back in the mode the caller left it in.
        p_model.train(was_training)

    # max(..., 1) keeps an empty loader from raising ZeroDivisionError.
    return Correct / max(TotalNumOfSamples, 1)

In [None]:
# Use the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
# Number of training epochs run by the experiment cell.
EPOCHS = 100

In [None]:
# -------------------------------------------- CIFAR-10 Dataset -----------------------------------------------------------
Option = 5
Train_DataLoader, Val_DataLoader, Test_DataLoader = CreateDataLoaders(Option, 64)
ModelName = "Model_AllConv_IOCN_CIFAR10.pt"

model = AllConv_IOCN(10).to(device)
loss_criteria = nn.CrossEntropyLoss()
AdamOpt = torch.optim.Adam(model.parameters(), lr=0.0001)

# Best validation accuracy seen so far; a checkpoint is written whenever it improves.
ValAccuracy = 0

# Set to True to skip training and only evaluate the checkpoint already on disk.
saved = False
if not saved:
    Train_Accuracy = 0

    for e in range(EPOCHS):
        model.train()
        Train_Accuracy = TrainModel(model, loss_criteria, AdamOpt, device, Train_DataLoader)
        val_acc = EvaluateModel(model, Val_DataLoader, device)

        print("EPOCH - ", e+1, ". Train Accuracy = ", Train_Accuracy.item(), ", Validation Accuracy = ", val_acc.item())

        if val_acc.item() > ValAccuracy:
            print("Model Re-Saved")
            ValAccuracy = val_acc.item()
            # Store only the weights: a state_dict loads under torch.load's
            # weights-only mode and does not depend on pickling this class.
            torch.save(model.state_dict(), ModelName)

# Load the best checkpoint straight onto the target device.
checkpoint = torch.load(ModelName, map_location=device)
if isinstance(checkpoint, nn.Module):
    # Legacy file written by an earlier version of this cell (whole pickled module).
    saved_model = checkpoint.to(device)
else:
    saved_model = AllConv_IOCN(10).to(device)
    saved_model.load_state_dict(checkpoint)
saved_model.eval()


print()

Train_Accuracy = EvaluateModel(saved_model, Train_DataLoader, device)
print("Train Accuracy = ", Train_Accuracy.item())
Test_Accuracy = EvaluateModel(saved_model, Test_DataLoader, device)
print("Test Accuracy = ", Test_Accuracy.item())

print("Generalization Gap = ", (Train_Accuracy.item() - Test_Accuracy.item()))

Files already downloaded and verified
Files already downloaded and verified


100%|██████████| 665/665 [00:10<00:00, 60.67it/s]
100%|██████████| 118/118 [00:01<00:00, 86.30it/s]


EPOCH -  1 . Train Accuracy =  0.1038588210940361 , Validation Accuracy =  0.14239999651908875
Model Re-Saved


100%|██████████| 665/665 [00:10<00:00, 60.77it/s]
100%|██████████| 118/118 [00:01<00:00, 89.86it/s]


EPOCH -  2 . Train Accuracy =  0.10804706066846848 , Validation Accuracy =  0.10920000076293945


100%|██████████| 665/665 [00:11<00:00, 60.23it/s]
100%|██████████| 118/118 [00:01<00:00, 87.53it/s]


EPOCH -  3 . Train Accuracy =  0.11228235065937042 , Validation Accuracy =  0.12746666371822357


100%|██████████| 665/665 [00:11<00:00, 59.16it/s]
100%|██████████| 118/118 [00:01<00:00, 85.09it/s]


EPOCH -  4 . Train Accuracy =  0.10868235677480698 , Validation Accuracy =  0.1112000048160553


100%|██████████| 665/665 [00:11<00:00, 57.14it/s]
100%|██████████| 118/118 [00:01<00:00, 81.39it/s]


EPOCH -  5 . Train Accuracy =  0.11089412122964859 , Validation Accuracy =  0.11906667053699493


100%|██████████| 665/665 [00:11<00:00, 58.68it/s]
100%|██████████| 118/118 [00:01<00:00, 85.35it/s]


EPOCH -  6 . Train Accuracy =  0.11515294015407562 , Validation Accuracy =  0.11400000005960464


100%|██████████| 665/665 [00:11<00:00, 58.58it/s]
100%|██████████| 118/118 [00:01<00:00, 83.05it/s]


EPOCH -  7 . Train Accuracy =  0.11569412052631378 , Validation Accuracy =  0.09626666456460953


100%|██████████| 665/665 [00:11<00:00, 58.98it/s]
100%|██████████| 118/118 [00:01<00:00, 84.80it/s]


EPOCH -  8 . Train Accuracy =  0.11804705858230591 , Validation Accuracy =  0.11840000003576279


100%|██████████| 665/665 [00:11<00:00, 59.43it/s]
100%|██████████| 118/118 [00:01<00:00, 85.57it/s]


EPOCH -  9 . Train Accuracy =  0.1179058849811554 , Validation Accuracy =  0.10013333708047867


100%|██████████| 665/665 [00:11<00:00, 59.47it/s]
100%|██████████| 118/118 [00:01<00:00, 85.38it/s]


EPOCH -  10 . Train Accuracy =  0.10894117504358292 , Validation Accuracy =  0.10733333230018616


100%|██████████| 665/665 [00:11<00:00, 58.91it/s]
100%|██████████| 118/118 [00:01<00:00, 87.88it/s]


EPOCH -  11 . Train Accuracy =  0.10070588439702988 , Validation Accuracy =  0.10946666449308395


100%|██████████| 665/665 [00:11<00:00, 59.40it/s]
100%|██████████| 118/118 [00:01<00:00, 85.88it/s]


EPOCH -  12 . Train Accuracy =  0.10604705661535263 , Validation Accuracy =  0.12133333832025528


100%|██████████| 665/665 [00:11<00:00, 59.32it/s]
100%|██████████| 118/118 [00:01<00:00, 87.27it/s]


EPOCH -  13 . Train Accuracy =  0.1124705895781517 , Validation Accuracy =  0.109333336353302


100%|██████████| 665/665 [00:11<00:00, 59.36it/s]
100%|██████████| 118/118 [00:01<00:00, 84.10it/s]


EPOCH -  14 . Train Accuracy =  0.11496470868587494 , Validation Accuracy =  0.09520000219345093


100%|██████████| 665/665 [00:11<00:00, 56.32it/s]
100%|██████████| 118/118 [00:01<00:00, 87.27it/s]


EPOCH -  15 . Train Accuracy =  0.11642353236675262 , Validation Accuracy =  0.09613333642482758


100%|██████████| 665/665 [00:11<00:00, 59.51it/s]
100%|██████████| 118/118 [00:01<00:00, 83.24it/s]


EPOCH -  16 . Train Accuracy =  0.11181176453828812 , Validation Accuracy =  0.09666667133569717


100%|██████████| 665/665 [00:11<00:00, 58.76it/s]
100%|██████████| 118/118 [00:01<00:00, 88.67it/s]


EPOCH -  17 . Train Accuracy =  0.10091764479875565 , Validation Accuracy =  0.12666666507720947


100%|██████████| 665/665 [00:11<00:00, 59.50it/s]
100%|██████████| 118/118 [00:01<00:00, 88.09it/s]


EPOCH -  18 . Train Accuracy =  0.10407058894634247 , Validation Accuracy =  0.09546666592359543


100%|██████████| 665/665 [00:11<00:00, 59.45it/s]
100%|██████████| 118/118 [00:01<00:00, 85.57it/s]


EPOCH -  19 . Train Accuracy =  0.10762353241443634 , Validation Accuracy =  0.09666667133569717


100%|██████████| 665/665 [00:11<00:00, 59.19it/s]
100%|██████████| 118/118 [00:01<00:00, 85.10it/s]


EPOCH -  20 . Train Accuracy =  0.1052941158413887 , Validation Accuracy =  0.09453333169221878


100%|██████████| 665/665 [00:11<00:00, 59.23it/s]
100%|██████████| 118/118 [00:01<00:00, 78.03it/s]


EPOCH -  21 . Train Accuracy =  0.10028235614299774 , Validation Accuracy =  0.11293333768844604


100%|██████████| 665/665 [00:11<00:00, 60.24it/s]
100%|██████████| 118/118 [00:01<00:00, 69.67it/s]


EPOCH -  22 . Train Accuracy =  0.09967058897018433 , Validation Accuracy =  0.10013333708047867


100%|██████████| 665/665 [00:11<00:00, 59.99it/s]
100%|██████████| 118/118 [00:01<00:00, 64.09it/s]


EPOCH -  23 . Train Accuracy =  0.10152941197156906 , Validation Accuracy =  0.10093333572149277


100%|██████████| 665/665 [00:10<00:00, 61.39it/s]
100%|██████████| 118/118 [00:01<00:00, 60.56it/s]


EPOCH -  24 . Train Accuracy =  0.0985882356762886 , Validation Accuracy =  0.10013333708047867


100%|██████████| 665/665 [00:10<00:00, 61.33it/s]
100%|██████████| 118/118 [00:01<00:00, 64.72it/s]


EPOCH -  25 . Train Accuracy =  0.10251764953136444 , Validation Accuracy =  0.14480000734329224
Model Re-Saved


100%|██████████| 665/665 [00:10<00:00, 60.79it/s]
100%|██████████| 118/118 [00:01<00:00, 70.15it/s]


EPOCH -  26 . Train Accuracy =  0.1255764663219452 , Validation Accuracy =  0.14346666634082794


100%|██████████| 665/665 [00:11<00:00, 60.14it/s]
100%|██████████| 118/118 [00:01<00:00, 76.38it/s]


EPOCH -  27 . Train Accuracy =  0.1370352953672409 , Validation Accuracy =  0.13093332946300507


100%|██████████| 665/665 [00:11<00:00, 59.56it/s]
100%|██████████| 118/118 [00:01<00:00, 85.81it/s]


EPOCH -  28 . Train Accuracy =  0.14522352814674377 , Validation Accuracy =  0.16120000183582306
Model Re-Saved


100%|██████████| 665/665 [00:11<00:00, 58.70it/s]
100%|██████████| 118/118 [00:01<00:00, 86.92it/s]


EPOCH -  29 . Train Accuracy =  0.14872941374778748 , Validation Accuracy =  0.13626666367053986


100%|██████████| 665/665 [00:11<00:00, 59.45it/s]
100%|██████████| 118/118 [00:01<00:00, 86.13it/s]


EPOCH -  30 . Train Accuracy =  0.1467764675617218 , Validation Accuracy =  0.1687999963760376
Model Re-Saved


100%|██████████| 665/665 [00:11<00:00, 60.12it/s]
100%|██████████| 118/118 [00:01<00:00, 89.46it/s]


EPOCH -  31 . Train Accuracy =  0.1448470652103424 , Validation Accuracy =  0.09960000216960907


100%|██████████| 665/665 [00:11<00:00, 60.05it/s]
100%|██████████| 118/118 [00:01<00:00, 87.31it/s]


EPOCH -  32 . Train Accuracy =  0.14301176369190216 , Validation Accuracy =  0.12640000879764557


100%|██████████| 665/665 [00:11<00:00, 57.22it/s]
100%|██████████| 118/118 [00:01<00:00, 86.38it/s]


EPOCH -  33 . Train Accuracy =  0.1485176533460617 , Validation Accuracy =  0.133733332157135


100%|██████████| 665/665 [00:11<00:00, 59.61it/s]
100%|██████████| 118/118 [00:01<00:00, 88.26it/s]


EPOCH -  34 . Train Accuracy =  0.1477411836385727 , Validation Accuracy =  0.10159999877214432


100%|██████████| 665/665 [00:11<00:00, 59.67it/s]
100%|██████████| 118/118 [00:01<00:00, 86.71it/s]


EPOCH -  35 . Train Accuracy =  0.15049411356449127 , Validation Accuracy =  0.1720000058412552
Model Re-Saved


100%|██████████| 665/665 [00:11<00:00, 60.24it/s]
100%|██████████| 118/118 [00:01<00:00, 88.07it/s]


EPOCH -  36 . Train Accuracy =  0.15350587666034698 , Validation Accuracy =  0.11426667124032974


100%|██████████| 665/665 [00:11<00:00, 59.97it/s]
100%|██████████| 118/118 [00:01<00:00, 88.27it/s]


EPOCH -  37 . Train Accuracy =  0.15409411489963531 , Validation Accuracy =  0.12453333288431168


100%|██████████| 665/665 [00:11<00:00, 59.20it/s]
100%|██████████| 118/118 [00:01<00:00, 86.61it/s]


EPOCH -  38 . Train Accuracy =  0.15329411625862122 , Validation Accuracy =  0.14786666631698608


100%|██████████| 665/665 [00:11<00:00, 59.83it/s]
100%|██████████| 118/118 [00:01<00:00, 85.74it/s]


EPOCH -  39 . Train Accuracy =  0.15736471116542816 , Validation Accuracy =  0.14573334157466888


100%|██████████| 665/665 [00:11<00:00, 59.72it/s]
100%|██████████| 118/118 [00:01<00:00, 85.75it/s]


EPOCH -  40 . Train Accuracy =  0.15910588204860687 , Validation Accuracy =  0.1404000073671341


100%|██████████| 665/665 [00:11<00:00, 59.06it/s]
100%|██████████| 118/118 [00:01<00:00, 86.52it/s]


EPOCH -  41 . Train Accuracy =  0.1606588214635849 , Validation Accuracy =  0.13146667182445526


100%|██████████| 665/665 [00:11<00:00, 59.71it/s]
100%|██████████| 118/118 [00:01<00:00, 90.17it/s]


EPOCH -  42 . Train Accuracy =  0.15995293855667114 , Validation Accuracy =  0.15146666765213013


100%|██████████| 665/665 [00:11<00:00, 59.60it/s]
100%|██████████| 118/118 [00:01<00:00, 81.86it/s]


EPOCH -  43 . Train Accuracy =  0.16150587797164917 , Validation Accuracy =  0.15960000455379486


100%|██████████| 665/665 [00:11<00:00, 59.19it/s]
100%|██████████| 118/118 [00:01<00:00, 87.96it/s]


EPOCH -  44 . Train Accuracy =  0.1598588228225708 , Validation Accuracy =  0.14866666495800018


100%|██████████| 665/665 [00:11<00:00, 59.45it/s]
100%|██████████| 118/118 [00:01<00:00, 87.93it/s]


EPOCH -  45 . Train Accuracy =  0.16094118356704712 , Validation Accuracy =  0.171466663479805


100%|██████████| 665/665 [00:11<00:00, 59.71it/s]
100%|██████████| 118/118 [00:01<00:00, 83.64it/s]


EPOCH -  46 . Train Accuracy =  0.16635294258594513 , Validation Accuracy =  0.1717333346605301


100%|██████████| 665/665 [00:11<00:00, 59.10it/s]
100%|██████████| 118/118 [00:01<00:00, 86.81it/s]


EPOCH -  47 . Train Accuracy =  0.16124705970287323 , Validation Accuracy =  0.17866666615009308
Model Re-Saved


100%|██████████| 665/665 [00:11<00:00, 59.90it/s]
100%|██████████| 118/118 [00:01<00:00, 88.18it/s]


EPOCH -  48 . Train Accuracy =  0.16498823463916779 , Validation Accuracy =  0.15266667306423187


100%|██████████| 665/665 [00:11<00:00, 60.08it/s]
100%|██████████| 118/118 [00:01<00:00, 87.40it/s]


EPOCH -  49 . Train Accuracy =  0.16505882143974304 , Validation Accuracy =  0.1409333348274231


100%|██████████| 665/665 [00:11<00:00, 57.19it/s]
100%|██████████| 118/118 [00:01<00:00, 84.33it/s]


EPOCH -  50 . Train Accuracy =  0.1654588282108307 , Validation Accuracy =  0.12506666779518127


100%|██████████| 665/665 [00:11<00:00, 59.88it/s]
100%|██████████| 118/118 [00:01<00:00, 85.14it/s]


EPOCH -  51 . Train Accuracy =  0.16663528978824615 , Validation Accuracy =  0.14426666498184204


100%|██████████| 665/665 [00:11<00:00, 58.99it/s]
100%|██████████| 118/118 [00:01<00:00, 80.74it/s]


EPOCH -  52 . Train Accuracy =  0.16931764781475067 , Validation Accuracy =  0.15533334016799927


100%|██████████| 665/665 [00:11<00:00, 58.42it/s]
100%|██████████| 118/118 [00:01<00:00, 71.61it/s]


EPOCH -  53 . Train Accuracy =  0.1660941243171692 , Validation Accuracy =  0.1720000058412552


100%|██████████| 665/665 [00:11<00:00, 60.12it/s]
100%|██████████| 118/118 [00:01<00:00, 61.72it/s]


EPOCH -  54 . Train Accuracy =  0.17098823189735413 , Validation Accuracy =  0.15733332931995392


100%|██████████| 665/665 [00:10<00:00, 61.04it/s]
100%|██████████| 118/118 [00:01<00:00, 60.58it/s]


EPOCH -  55 . Train Accuracy =  0.17169411480426788 , Validation Accuracy =  0.16760000586509705


100%|██████████| 665/665 [00:10<00:00, 60.89it/s]
100%|██████████| 118/118 [00:01<00:00, 65.86it/s]


EPOCH -  56 . Train Accuracy =  0.1676941215991974 , Validation Accuracy =  0.16093333065509796


100%|██████████| 665/665 [00:11<00:00, 59.62it/s]
100%|██████████| 118/118 [00:01<00:00, 73.65it/s]


EPOCH -  57 . Train Accuracy =  0.17512941360473633 , Validation Accuracy =  0.19413334131240845
Model Re-Saved


100%|██████████| 665/665 [00:11<00:00, 59.63it/s]
100%|██████████| 118/118 [00:01<00:00, 86.36it/s]


EPOCH -  58 . Train Accuracy =  0.17536470293998718 , Validation Accuracy =  0.13040000200271606


100%|██████████| 665/665 [00:11<00:00, 58.83it/s]
100%|██████████| 118/118 [00:01<00:00, 86.82it/s]


EPOCH -  59 . Train Accuracy =  0.17578823864459991 , Validation Accuracy =  0.11479999870061874


100%|██████████| 665/665 [00:11<00:00, 59.37it/s]
100%|██████████| 118/118 [00:01<00:00, 86.02it/s]


EPOCH -  60 . Train Accuracy =  0.17625881731510162 , Validation Accuracy =  0.164000004529953


100%|██████████| 665/665 [00:11<00:00, 59.48it/s]
100%|██████████| 118/118 [00:01<00:00, 87.68it/s]


EPOCH -  61 . Train Accuracy =  0.1754823476076126 , Validation Accuracy =  0.12319999933242798


100%|██████████| 665/665 [00:11<00:00, 59.74it/s]
100%|██████████| 118/118 [00:01<00:00, 84.75it/s]


EPOCH -  62 . Train Accuracy =  0.17287059128284454 , Validation Accuracy =  0.14400000870227814


100%|██████████| 665/665 [00:11<00:00, 59.76it/s]
100%|██████████| 118/118 [00:01<00:00, 84.61it/s]


EPOCH -  63 . Train Accuracy =  0.17496471107006073 , Validation Accuracy =  0.13893333077430725


100%|██████████| 665/665 [00:11<00:00, 59.32it/s]
100%|██████████| 118/118 [00:01<00:00, 87.77it/s]


EPOCH -  64 . Train Accuracy =  0.17635294795036316 , Validation Accuracy =  0.17466667294502258


100%|██████████| 665/665 [00:11<00:00, 59.02it/s]
100%|██████████| 118/118 [00:01<00:00, 85.25it/s]


EPOCH -  65 . Train Accuracy =  0.1802823543548584 , Validation Accuracy =  0.18346667289733887


100%|██████████| 665/665 [00:11<00:00, 59.34it/s]
100%|██████████| 118/118 [00:01<00:00, 88.70it/s]


EPOCH -  66 . Train Accuracy =  0.17679999768733978 , Validation Accuracy =  0.13840000331401825


100%|██████████| 665/665 [00:11<00:00, 59.81it/s]
100%|██████████| 118/118 [00:01<00:00, 86.16it/s]


EPOCH -  67 . Train Accuracy =  0.1814117729663849 , Validation Accuracy =  0.14813333749771118


100%|██████████| 665/665 [00:11<00:00, 56.59it/s]
100%|██████████| 118/118 [00:01<00:00, 81.96it/s]


EPOCH -  68 . Train Accuracy =  0.17870588600635529 , Validation Accuracy =  0.18199999630451202


100%|██████████| 665/665 [00:11<00:00, 59.31it/s]
100%|██████████| 118/118 [00:01<00:00, 85.05it/s]


EPOCH -  69 . Train Accuracy =  0.17983528971672058 , Validation Accuracy =  0.1674666702747345


100%|██████████| 665/665 [00:11<00:00, 59.18it/s]
100%|██████████| 118/118 [00:01<00:00, 84.72it/s]


EPOCH -  70 . Train Accuracy =  0.1793411821126938 , Validation Accuracy =  0.1889333426952362


100%|██████████| 665/665 [00:11<00:00, 59.21it/s]
100%|██████████| 118/118 [00:01<00:00, 86.25it/s]


EPOCH -  71 . Train Accuracy =  0.18637646734714508 , Validation Accuracy =  0.1679999977350235


100%|██████████| 665/665 [00:11<00:00, 59.41it/s]
100%|██████████| 118/118 [00:01<00:00, 86.30it/s]


EPOCH -  72 . Train Accuracy =  0.18000000715255737 , Validation Accuracy =  0.1701333373785019


100%|██████████| 665/665 [00:11<00:00, 59.61it/s]
100%|██████████| 118/118 [00:01<00:00, 86.63it/s]


EPOCH -  73 . Train Accuracy =  0.18524706363677979 , Validation Accuracy =  0.1616000086069107


100%|██████████| 665/665 [00:11<00:00, 59.73it/s]
100%|██████████| 118/118 [00:01<00:00, 87.66it/s]


EPOCH -  74 . Train Accuracy =  0.18305882811546326 , Validation Accuracy =  0.16600000858306885


100%|██████████| 665/665 [00:11<00:00, 59.20it/s]
100%|██████████| 118/118 [00:01<00:00, 88.60it/s]


EPOCH -  75 . Train Accuracy =  0.18562352657318115 , Validation Accuracy =  0.1950666755437851
Model Re-Saved


100%|██████████| 665/665 [00:11<00:00, 59.26it/s]
100%|██████████| 118/118 [00:01<00:00, 87.25it/s]


EPOCH -  76 . Train Accuracy =  0.18148235976696014 , Validation Accuracy =  0.18346667289733887


100%|██████████| 665/665 [00:11<00:00, 59.06it/s]
100%|██████████| 118/118 [00:01<00:00, 89.47it/s]


EPOCH -  77 . Train Accuracy =  0.18807059526443481 , Validation Accuracy =  0.14000000059604645


100%|██████████| 665/665 [00:11<00:00, 59.57it/s]
100%|██████████| 118/118 [00:01<00:00, 86.56it/s]


EPOCH -  78 . Train Accuracy =  0.1852235347032547 , Validation Accuracy =  0.18573333323001862


100%|██████████| 665/665 [00:11<00:00, 59.21it/s]
100%|██████████| 118/118 [00:01<00:00, 85.26it/s]


EPOCH -  79 . Train Accuracy =  0.18912941217422485 , Validation Accuracy =  0.20053333044052124
Model Re-Saved


100%|██████████| 665/665 [00:11<00:00, 59.43it/s]
100%|██████████| 118/118 [00:01<00:00, 87.68it/s]


EPOCH -  80 . Train Accuracy =  0.18774117529392242 , Validation Accuracy =  0.15760000050067902


100%|██████████| 665/665 [00:11<00:00, 59.10it/s]
100%|██████████| 118/118 [00:01<00:00, 84.84it/s]


EPOCH -  81 . Train Accuracy =  0.1884235292673111 , Validation Accuracy =  0.15919999778270721


100%|██████████| 665/665 [00:11<00:00, 59.09it/s]
100%|██████████| 118/118 [00:01<00:00, 87.56it/s]


EPOCH -  82 . Train Accuracy =  0.18976470828056335 , Validation Accuracy =  0.1881333291530609


100%|██████████| 665/665 [00:11<00:00, 59.02it/s]
100%|██████████| 118/118 [00:01<00:00, 76.84it/s]


EPOCH -  83 . Train Accuracy =  0.19122353196144104 , Validation Accuracy =  0.1674666702747345


100%|██████████| 665/665 [00:10<00:00, 60.49it/s]
100%|██████████| 118/118 [00:01<00:00, 74.88it/s]


EPOCH -  84 . Train Accuracy =  0.1912941187620163 , Validation Accuracy =  0.17599999904632568


100%|██████████| 665/665 [00:11<00:00, 57.57it/s]
100%|██████████| 118/118 [00:01<00:00, 62.25it/s]


EPOCH -  85 . Train Accuracy =  0.19327059388160706 , Validation Accuracy =  0.15839999914169312


100%|██████████| 665/665 [00:10<00:00, 61.76it/s]
100%|██████████| 118/118 [00:01<00:00, 63.34it/s]


EPOCH -  86 . Train Accuracy =  0.19228234887123108 , Validation Accuracy =  0.16466666758060455


100%|██████████| 665/665 [00:10<00:00, 61.57it/s]
100%|██████████| 118/118 [00:01<00:00, 67.37it/s]


EPOCH -  87 . Train Accuracy =  0.1928941160440445 , Validation Accuracy =  0.1910666674375534


100%|██████████| 665/665 [00:10<00:00, 61.11it/s]
100%|██████████| 118/118 [00:01<00:00, 73.25it/s]


EPOCH -  88 . Train Accuracy =  0.193529412150383 , Validation Accuracy =  0.15280000865459442


100%|██████████| 665/665 [00:11<00:00, 59.65it/s]
100%|██████████| 118/118 [00:01<00:00, 81.04it/s]


EPOCH -  89 . Train Accuracy =  0.19087058305740356 , Validation Accuracy =  0.18146666884422302


100%|██████████| 665/665 [00:11<00:00, 59.96it/s]
100%|██████████| 118/118 [00:01<00:00, 88.62it/s]


EPOCH -  90 . Train Accuracy =  0.19604705274105072 , Validation Accuracy =  0.18880000710487366


100%|██████████| 665/665 [00:11<00:00, 59.69it/s]
100%|██████████| 118/118 [00:01<00:00, 86.04it/s]


EPOCH -  91 . Train Accuracy =  0.1936705857515335 , Validation Accuracy =  0.1350666731595993


100%|██████████| 665/665 [00:11<00:00, 59.55it/s]
100%|██████████| 118/118 [00:01<00:00, 85.01it/s]


EPOCH -  92 . Train Accuracy =  0.19872941076755524 , Validation Accuracy =  0.19066667556762695


100%|██████████| 665/665 [00:11<00:00, 59.75it/s]
100%|██████████| 118/118 [00:01<00:00, 87.84it/s]


EPOCH -  93 . Train Accuracy =  0.19632941484451294 , Validation Accuracy =  0.18453334271907806


100%|██████████| 665/665 [00:11<00:00, 59.12it/s]
100%|██████████| 118/118 [00:01<00:00, 87.05it/s]


EPOCH -  94 . Train Accuracy =  0.19950588047504425 , Validation Accuracy =  0.21119999885559082
Model Re-Saved


100%|██████████| 665/665 [00:11<00:00, 59.32it/s]
100%|██████████| 118/118 [00:01<00:00, 87.21it/s]


EPOCH -  95 . Train Accuracy =  0.2035764753818512 , Validation Accuracy =  0.20720000565052032


100%|██████████| 665/665 [00:11<00:00, 59.88it/s]
100%|██████████| 118/118 [00:01<00:00, 86.55it/s]


EPOCH -  96 . Train Accuracy =  0.2027764767408371 , Validation Accuracy =  0.2136000096797943
Model Re-Saved


100%|██████████| 665/665 [00:11<00:00, 59.93it/s]
100%|██████████| 118/118 [00:01<00:00, 88.14it/s]


EPOCH -  97 . Train Accuracy =  0.2032705843448639 , Validation Accuracy =  0.18199999630451202


100%|██████████| 665/665 [00:11<00:00, 59.61it/s]
100%|██████████| 118/118 [00:01<00:00, 87.14it/s]


EPOCH -  98 . Train Accuracy =  0.20360000431537628 , Validation Accuracy =  0.19466666877269745


100%|██████████| 665/665 [00:11<00:00, 60.04it/s]
100%|██████████| 118/118 [00:01<00:00, 87.71it/s]


EPOCH -  99 . Train Accuracy =  0.20482352375984192 , Validation Accuracy =  0.1934666633605957


100%|██████████| 665/665 [00:11<00:00, 59.95it/s]
100%|██████████| 118/118 [00:01<00:00, 88.76it/s]


EPOCH -  100 . Train Accuracy =  0.20465882122516632 , Validation Accuracy =  0.1982666701078415



100%|██████████| 665/665 [00:08<00:00, 78.46it/s]


Train Accuracy =  0.21279999613761902


100%|██████████| 157/157 [00:01<00:00, 88.86it/s]

Test Accuracy =  0.21329998970031738
Generalization Gap =  -0.0004999935626983643





In [None]:

# Build a fresh, untrained network and print the name of each learnable
# parameter, one per line. NOTE: this rebinds `model`, replacing whatever
# network that name referred to before this cell ran.
model = AllConv_IOCN(10).to(device)
for name in dict(model.named_parameters()):
    print(name)

Conv1.weight
Conv1.bias
Conv2.weight
Conv2.bias
Conv3.weight
Conv3.bias
Conv4.weight
Conv4.bias
Conv5.weight
Conv5.bias
Conv6.weight
Conv6.bias
Conv7.weight
Conv7.bias
Conv8.weight
Conv8.bias
Conv9.weight
Conv9.bias
