In [1]:
import timeit
from sklearn.decomposition import TruncatedSVD
from scipy.sparse import csr_matrix
import numpy as np
import copy
import matplotlib.pyplot as plt


def lmform(
    data_list,
    config,
    join,
    transfer=None):
    """Fit the lmform model by repeated calls to ``update_set_lmform``.

    Args:
        data_list: Mapping from data-set key to the data matrix for that key.
        config: Settings dict. This function reads ``seed``, ``verbose``,
            ``tol`` and ``max_iter`` (plus ``i_dim``/``j_dim`` for the verbose
            banner); it is also forwarded to ``initialise_lmform`` and
            ``update_set_lmform``.
        join: ``join["complete"]`` holds parallel lists (``data_list``,
            ``beta``, ``code``, ``encode``) giving, per data set, the keys
            under which its parameters and codes are stored.
        transfer: Optional dict with ``fix`` flags and ``main_code`` /
            ``main_parameters`` starting values. When ``None``, nothing is
            fixed and every starting value is an empty placeholder.

    Returns:
        dict with ``main_parameters``, ``main_code``, ``meta_parameters``
        (``config`` and ``join``), ``convergence_parameters`` (``count`` and
        ``score_vec``) and ``run_time`` (``start``, ``stop``, ``run_time``).
    """
    start = timeit.default_timer()

    np.random.seed(seed=config["seed"])

    complete = join["complete"]
    n_sets = len(complete["data_list"])

    if transfer is None:
        # A missing ``transfer`` used to crash on ``transfer["fix"]``.
        # The encode placeholders cover both key lists because the
        # initialiser may look them up under either one.
        transfer = {
            "fix": {
                "beta": False,
                "code": False,
                "intercept": False,
                "encode": False,
            },
            "main_code": {
                "code": {key: [] for key in complete["code"]},
                "encode": {
                    key: []
                    for key in (*complete["code"], *complete["encode"])
                },
            },
            "main_parameters": {
                "beta": {key: [] for key in complete["beta"]},
                "intercept": {key: [] for key in complete["data_list"]},
            },
        }

    convergence_parameters = {"count": 0, "score_vec": [10e6]}

    main_parameters, main_code = initialise_lmform(data_list = data_list,
                                         config = config,
                                         join = join,
                                         transfer = transfer
                                        )

    if (config["verbose"]):
        # ``i_dim`` is not used by this model, so report it only if present
        # instead of failing with a KeyError.
        print("Beginning lmform learning with:    Sample dimension reduction (config[i_dim]): " + str( config.get("i_dim") ) + "    Feature dimension reduction (config[j_dim]): " + str( config["j_dim"] ) + "    Tolerance Threshold: " + str( config["tol"] ) + "   Maximum number of iterations: "  + str( config["max_iter"] ) + "   Verbose: ", config["verbose"])

    while True:
        # Snapshot the mapping, not just a reference to it: the loop below
        # rebinds entries of main_code["encode"], so keeping the same dict
        # would compare every encoding with itself and always score 0.
        prev_encode = dict(main_code["encode"])

        for i in range(n_sets):
            data_key = complete["data_list"][i]
            beta_key = complete["beta"][i]
            code_key = complete["code"][i]

            internal_parameters = {
                "beta": main_parameters["beta"][beta_key],
                "intercept": main_parameters["intercept"][data_key],
            }
            internal_code = {
                "encode": main_code["encode"][code_key],
                "code": main_code["code"][code_key],
            }

            internal_parameters, internal_code = update_set_lmform(
                x = data_list[data_key],
                main_parameters = internal_parameters,
                main_code = internal_code,
                config = config,
                fix = transfer["fix"]
            )

            main_parameters["beta"][beta_key] = internal_parameters["beta"]
            main_parameters["intercept"][data_key] = internal_parameters["intercept"]

            main_code["code"][code_key] = internal_code["code"]
            main_code["encode"][code_key] = internal_code["encode"]

        # Mean absolute change of the encodings, summed over data sets.
        total_mae = 0
        for i in range(n_sets):
            code_key = complete["code"][i]
            total_mae += torch.mean(torch.abs(main_code["encode"][code_key] - prev_encode[code_key]))

        # Check convergence
        convergence_parameters["score_vec"].append(total_mae)
        current_score = convergence_parameters["score_vec"][-1]
        previous_score = convergence_parameters["score_vec"][-2]

        # The first pass is compared against the 10e6 sentinel, so stopping
        # is only considered from the second pass onwards.
        if convergence_parameters["count"] >= 1:
            score_change = abs(previous_score - current_score)
            if config["verbose"]:
                print("Iteration:   "+str(convergence_parameters["count"])+"   with Tolerance of:   "+str(score_change))
            if convergence_parameters["count"] >= config["max_iter"]:
                break
            if score_change < config["tol"]:
                break
        convergence_parameters["count"] += 1

    if (config["verbose"]):
        print("Learning has converged for lmform, beginning (if requested) dimension reduction")

    stop = timeit.default_timer()

    return {
        "main_parameters": main_parameters,
        "main_code": main_code,
        "meta_parameters": {"config": config, "join": join},
        "convergence_parameters": convergence_parameters,
        "run_time": {"start": start, "stop": stop, "run_time": stop - start},
    }
               
               
def initialise_lmform(
    data_list,
    config,
    join,
    transfer
):
    """Build the starting parameters and codes for ``lmform``.

    For every data set listed in ``join["complete"]`` each slot (beta,
    encode, code, intercept) is filled once: from ``transfer`` when a
    non-empty value is supplied there, otherwise from the data.

    Args:
        data_list: Mapping from data-set key to data matrix. The matrices are
            used with ``torch`` operations, so they are expected to be torch
            tensors.
        config: Reads ``j_dim`` and ``init["beta"]`` when beta has to be
            initialised from the data.
        join: ``join["complete"]`` with parallel key lists ``data_list``,
            ``beta``, ``code`` and ``encode``.
        transfer: Dict with ``main_parameters`` (``beta``, ``intercept``) and
            ``main_code`` (``code``, ``encode``) starting values; an empty
            value means "not supplied".

    Returns:
        ``(main_parameters, main_code)``. ``main_code["encode"]`` is stored
        under the *code* key, which is where ``lmform`` looks it up.
    """
    def is_unset(value):
        # Placeholders are empty lists. Comparing a tensor/array with ``[]``
        # is not a reliable emptiness test, so check the type explicitly.
        return isinstance(value, list) and not value

    complete = join["complete"]
    n_sets = len(complete["data_list"])

    main_code = {"code": {}, "encode": {}}
    main_parameters = {"beta": {}, "intercept": {}}

    # Create an empty placeholder for every slot that will be filled below.
    for i in range(n_sets):
        main_code["code"][complete["code"][i]] = []
        main_code["encode"][complete["code"][i]] = []

        main_parameters["beta"][complete["beta"][i]] = []
        main_parameters["intercept"][complete["data_list"][i]] = []

    for i in range(n_sets):
        data_key = complete["data_list"][i]
        beta_key = complete["beta"][i]
        code_key = complete["code"][i]
        encode_key = complete["encode"][i]
        x = data_list[data_key]

        if is_unset(main_parameters["beta"][beta_key]):
            supplied = transfer["main_parameters"]["beta"][beta_key]
            if len(supplied) != 0:
                main_parameters["beta"][beta_key] = supplied
            else:
                initial = initialise_parameters_lmform(x = x, dim_main = config["j_dim"], seed_main = 1, type_main = config["init"]["beta"])
                # The helper yields a NumPy array; convert it so the torch
                # matrix products below work with the data's dtype/device.
                main_parameters["beta"][beta_key] = torch.as_tensor(initial, dtype=x.dtype, device=x.device).T

        beta = main_parameters["beta"][beta_key]

        # The guard used to inspect main_code["code"] under the encode key;
        # the encode slot itself (stored under the code key) is what matters.
        if is_unset(main_code["encode"][code_key]):
            supplied = transfer["main_code"]["encode"][encode_key]
            if len(supplied) != 0:
                main_code["encode"][code_key] = supplied
            else:
                main_code["encode"][code_key] = x@beta

        if is_unset(main_code["code"][code_key]):
            supplied = transfer["main_code"]["code"][code_key]
            if len(supplied) != 0:
                main_code["code"][code_key] = supplied
            else:
                main_code["code"][code_key] = x@torch.linalg.pinv(beta.T)

        if is_unset(main_parameters["intercept"][data_key]):
            supplied = transfer["main_parameters"]["intercept"][data_key]
            if len(supplied) != 0:
                main_parameters["intercept"][data_key] = supplied
            else:
                # Column-wise mean of what code @ beta.T does not explain.
                main_parameters["intercept"][data_key] = torch.mean(x - main_code["code"][code_key]@(beta.T),0)

    return main_parameters, main_code
               
               
               
def initialise_parameters_lmform(
                            x,
                            dim_main, 
                            seed_main,
                            type_main):
    """Return an initial ``(dim_main, x.shape[1])`` loading matrix for ``x``.

    Args:
        x: Data matrix with features along axis 1.
        dim_main: Number of components (rows of the result).
        seed_main: Random state for the ``"SVD"`` initialiser. The ``"rand"``
            initialiser ignores it and draws from NumPy's global generator.
        type_main: ``"SVD"`` for the components of a ``TruncatedSVD`` fit, or
            ``"rand"`` for standard-normal noise.

    Returns:
        NumPy array with ``dim_main`` rows.

    Raises:
        ValueError: If ``type_main`` is neither ``"SVD"`` nor ``"rand"``.
    """
    if type_main == "SVD":
        svd = TruncatedSVD(n_components=dim_main, n_iter=2, random_state=seed_main)
        return svd.fit(x).components_

    if type_main == "rand":
        return np.random.randn(dim_main, x.shape[1])

    # An unknown type used to fall through and return None, which only failed
    # later in the caller with an unrelated AttributeError.
    raise ValueError(
        "Unknown initialisation type " + repr(type_main) + "; expected 'SVD' or 'rand'"
    )
    
    
               
def update_set_lmform(x,main_parameters,main_code,config,fix):
    """Run one round of pseudo-inverse updates for a single data set.

    The entries of ``main_parameters`` (``beta``, ``intercept``) and
    ``main_code`` (``code``, ``encode``) are refreshed in the order code,
    beta, intercept, encode; each step is skipped when its ``fix`` flag is
    truthy. Both dicts are modified in place and also returned. ``config``
    is accepted but not read.
    """
    pinv = torch.linalg.pinv

    code_is_free = not fix["code"]
    if code_is_free:
        centred = x - main_parameters["intercept"]
        main_code["code"] = centred @ pinv(main_parameters["beta"].T)

    beta_is_free = not fix["beta"]
    if beta_is_free:
        centred = x - main_parameters["intercept"]
        beta_t = pinv(main_code["code"]) @ centred
        main_parameters["beta"] = beta_t.T

    intercept_is_free = not fix["intercept"]
    if intercept_is_free:
        # Column-wise mean of the part of x the current fit leaves over.
        fitted = main_code["code"] @ main_parameters["beta"].T
        main_parameters["intercept"] = torch.mean(x - fitted, 0)

    encode_is_free = not fix["encode"]
    if encode_is_free:
        # Uses the intercept as it stands after the step above.
        centred = x - main_parameters["intercept"]
        main_code["encode"] = centred @ main_parameters["beta"]

    return main_parameters, main_code

In [2]:
import timeit
from sklearn.decomposition import TruncatedSVD
from scipy.sparse import csr_matrix
import numpy as np
import copy
import matplotlib.pyplot as plt

def ecdf(x, dims=None):
    """Map every element of ``x`` to its empirical-CDF value.

    The element with rank ``k`` (1-based, ascending) among the ``n`` flattened
    values is mapped to ``k / n``. Equal values receive distinct consecutive
    ranks in the order ``np.argsort`` places them.

    Args:
        x: Array-like input of any shape.
        dims: Shape of the result; defaults to the shape of ``x``.

    Returns:
        Float NumPy array of shape ``dims`` with values in ``(0, 1]``.
    """
    values = np.asarray(x)
    if dims is None:
        dims = values.shape

    flat = values.ravel()
    order = np.argsort(flat)
    n = flat.size

    # order[k] is the position of the k-th smallest value, so the rank
    # fractions have to be scattered *through* the permutation. Indexing the
    # fractions with it (as before) returns the permutation itself, rescaled.
    fractions = np.empty(n, dtype=float)
    fractions[order] = np.arange(1, n + 1) / float(n)
    return fractions.reshape(dims)

def imshow(img):
    """Display ``img`` after passing it through ``ecdf``.

    Args:
        img: Array exposing ``.shape``; the shape is kept for the plot.
    """
    # ecdf needs the output shape; calling it with the image alone raised a
    # TypeError for the missing ``dims`` argument.
    img = ecdf(img, img.shape)
    plt.imshow(img)
    plt.show()

def gcode(
    data_list,
    config,
    join,
    transfer=None):
    """Fit the gcode model by repeated calls to ``update_set_gcode``.

    Args:
        data_list: Mapping from data-set key to the data matrix for that key.
        config: Settings dict. This function reads ``seed``, ``verbose``,
            ``tol`` and ``max_iter`` (plus ``i_dim``/``j_dim`` for the verbose
            banner); it is also forwarded to ``initialise_gcode`` and
            ``update_set_gcode``.
        join: ``join["complete"]`` holds parallel lists (``data_list``,
            ``alpha``, ``beta``, ``code``, ``encode``) giving, per data set,
            the keys under which its parameters and codes are stored.
        transfer: Optional dict with ``fix`` flags and ``main_code`` /
            ``main_parameters`` starting values. When ``None``, nothing is
            fixed and every starting value is an empty placeholder.

    Returns:
        dict with ``main_parameters``, ``main_code``, ``meta_parameters``
        (``config`` and ``join``), ``convergence_parameters`` (``count`` and
        ``score_vec``) and ``run_time`` (``start``, ``stop``, ``run_time``).
    """
    start = timeit.default_timer()

    np.random.seed(seed=config["seed"])

    complete = join["complete"]
    n_sets = len(complete["data_list"])

    if transfer is None:
        # A missing ``transfer`` used to crash on ``transfer["fix"]``.
        transfer = {
            "fix": {
                "alpha": False,
                "beta": False,
                "code": False,
                "intercept": False,
                "encode": False,
            },
            "main_code": {
                "code": {key: [] for key in complete["code"]},
                "encode": {
                    key: []
                    for key in (*complete["code"], *complete["encode"])
                },
            },
            "main_parameters": {
                "alpha": {key: [] for key in complete["alpha"]},
                "beta": {key: [] for key in complete["beta"]},
                "intercept": {key: [] for key in complete["data_list"]},
            },
        }

    convergence_parameters = {"count": 0, "score_vec": [10e6]}

    main_parameters, main_code = initialise_gcode(data_list = data_list,
                                         config = config,
                                         join = join,
                                         transfer = transfer
                                        )

    if (config["verbose"]):
        print("Beginning gcode learning with:    Sample dimension reduction (config[i_dim]): " + str( config["i_dim"] ) + "    Feature dimension reduction (config[j_dim]): " + str( config["j_dim"] ) + "    Tolerance Threshold: " + str( config["tol"] ) + "   Maximum number of iterations: "  + str( config["max_iter"] ) + "   Verbose: ", config["verbose"])

    while True:
        # Snapshot the mapping, not just a reference to it: the loop below
        # rebinds entries of main_code["encode"], so keeping the same dict
        # would compare every encoding with itself and always score 0.
        prev_encode = dict(main_code["encode"])

        for i in range(n_sets):
            data_key = complete["data_list"][i]
            alpha_key = complete["alpha"][i]
            beta_key = complete["beta"][i]
            code_key = complete["code"][i]

            internal_parameters = {
                "alpha": main_parameters["alpha"][alpha_key],
                "beta": main_parameters["beta"][beta_key],
                "intercept": main_parameters["intercept"][data_key],
            }
            internal_code = {
                "encode": main_code["encode"][code_key],
                "code": main_code["code"][code_key],
            }

            internal_parameters, internal_code = update_set_gcode(
                x = data_list[data_key],
                main_parameters = internal_parameters,
                main_code = internal_code,
                config = config,
                fix = transfer["fix"]
            )

            main_parameters["alpha"][alpha_key] = internal_parameters["alpha"]
            main_parameters["beta"][beta_key] = internal_parameters["beta"]
            main_parameters["intercept"][data_key] = internal_parameters["intercept"]

            main_code["code"][code_key] = internal_code["code"]
            main_code["encode"][code_key] = internal_code["encode"]

        # Mean absolute change of the encodings, summed over data sets.
        total_mae = 0
        for i in range(n_sets):
            code_key = complete["code"][i]
            total_mae += torch.mean(torch.abs(main_code["encode"][code_key] - prev_encode[code_key]))

        # Check convergence
        convergence_parameters["score_vec"].append(total_mae)
        current_score = convergence_parameters["score_vec"][-1]
        previous_score = convergence_parameters["score_vec"][-2]

        # The first pass is compared against the 10e6 sentinel, so stopping
        # is only considered from the second pass onwards.
        if convergence_parameters["count"] >= 1:
            score_change = abs(previous_score - current_score)
            if config["verbose"]:
                print("Iteration:   "+str(convergence_parameters["count"])+"   with Tolerance of:   "+str(score_change))
            if convergence_parameters["count"] >= config["max_iter"]:
                break
            if score_change < config["tol"]:
                break
        convergence_parameters["count"] += 1

    if (config["verbose"]):
        print("Learning has converged for gcode, beginning (if requested) dimension reduction")

    stop = timeit.default_timer()

    return {
        "main_parameters": main_parameters,
        "main_code": main_code,
        "meta_parameters": {"config": config, "join": join},
        "convergence_parameters": convergence_parameters,
        "run_time": {"start": start, "stop": stop, "run_time": stop - start},
    }
               
               
def initialise_gcode(
    data_list,
    config,
    join,
    transfer
):
    """Build the starting parameters and codes for ``gcode``.

    For every data set listed in ``join["complete"]`` each slot (alpha, beta,
    encode, code, intercept) is filled once: from ``transfer`` when a
    non-empty value is supplied there, otherwise from the data.

    Args:
        data_list: Mapping from data-set key to data matrix. The matrices are
            used with ``torch`` operations, so they are expected to be torch
            tensors.
        config: Reads ``i_dim``/``init["alpha"]`` and ``j_dim``/
            ``init["beta"]`` when alpha or beta has to be initialised from
            the data.
        join: ``join["complete"]`` with parallel key lists ``data_list``,
            ``alpha``, ``beta``, ``code`` and ``encode``.
        transfer: Dict with ``main_parameters`` (``alpha``, ``beta``,
            ``intercept``) and ``main_code`` (``code``, ``encode``) starting
            values; an empty value means "not supplied".

    Returns:
        ``(main_parameters, main_code)``. ``main_code["encode"]`` is stored
        under the *code* key, which is where ``gcode`` looks it up.
    """
    def is_unset(value):
        # Placeholders are empty lists. Comparing a tensor/array with ``[]``
        # is not a reliable emptiness test, so check the type explicitly.
        return isinstance(value, list) and not value

    complete = join["complete"]
    n_sets = len(complete["data_list"])

    main_code = {"code": {}, "encode": {}}
    main_parameters = {"alpha": {}, "beta": {}, "intercept": {}}

    # Create an empty placeholder for every slot that will be filled below.
    for i in range(n_sets):
        main_code["code"][complete["code"][i]] = []
        main_code["encode"][complete["code"][i]] = []

        main_parameters["alpha"][complete["alpha"][i]] = []
        main_parameters["beta"][complete["beta"][i]] = []
        main_parameters["intercept"][complete["data_list"][i]] = []

    for i in range(n_sets):
        data_key = complete["data_list"][i]
        alpha_key = complete["alpha"][i]
        beta_key = complete["beta"][i]
        code_key = complete["code"][i]
        encode_key = complete["encode"][i]
        x = data_list[data_key]

        if is_unset(main_parameters["alpha"][alpha_key]):
            supplied = transfer["main_parameters"]["alpha"][alpha_key]
            if len(supplied) != 0:
                main_parameters["alpha"][alpha_key] = supplied
            else:
                main_parameters["alpha"][alpha_key] = initialise_parameters_gcode(x = x.T, dim_main = config["i_dim"], seed_main = 1, type_main = config["init"]["alpha"])

        if is_unset(main_parameters["beta"][beta_key]):
            supplied = transfer["main_parameters"]["beta"][beta_key]
            if len(supplied) != 0:
                main_parameters["beta"][beta_key] = supplied
            else:
                main_parameters["beta"][beta_key] = initialise_parameters_gcode(x = x, dim_main = config["j_dim"], seed_main = 1, type_main = config["init"]["beta"]).T

        alpha = main_parameters["alpha"][alpha_key]
        beta = main_parameters["beta"][beta_key]

        # The encode slot is created (and later read by gcode) under the code
        # key. Looking it up under the encode key raised a KeyError whenever
        # the two key lists differed.
        if is_unset(main_code["encode"][code_key]):
            supplied = transfer["main_code"]["encode"][encode_key]
            if len(supplied) != 0:
                main_code["encode"][code_key] = supplied
            else:
                main_code["encode"][code_key] = alpha@x@beta

        if is_unset(main_code["code"][code_key]):
            supplied = transfer["main_code"]["code"][code_key]
            if len(supplied) != 0:
                main_code["code"][code_key] = supplied
            else:
                main_code["code"][code_key] = torch.linalg.pinv(alpha.T)@x@torch.linalg.pinv(beta.T)

        if is_unset(main_parameters["intercept"][data_key]):
            supplied = transfer["main_parameters"]["intercept"][data_key]
            if len(supplied) != 0:
                main_parameters["intercept"][data_key] = supplied
            else:
                # Column-wise mean of what alpha.T @ code @ beta.T leaves over.
                main_parameters["intercept"][data_key] = torch.mean(x - (alpha.T)@main_code["code"][code_key]@(beta.T),0)

    return main_parameters, main_code
               
               
               
def initialise_parameters_gcode(
                            x,
                            dim_main, 
                            seed_main,
                            type_main):
    """Return an initial ``(dim_main, x.shape[1])`` loading matrix for ``x``.

    Args:
        x: 2-D torch tensor with features along axis 1.
        dim_main: Number of components (rows of the result).
        seed_main: Accepted for signature parity with the lmform initialiser;
            not used here (``"rand"`` draws from torch's global generator).
        type_main: ``"SVD"`` for the leading right-singular vectors of ``x``,
            or ``"rand"`` for standard-normal noise.

    Returns:
        Tensor with ``dim_main`` rows. For ``"SVD"`` at most
        ``min(x.shape)`` rows exist, so fewer are returned if ``dim_main``
        exceeds that.

    Raises:
        ValueError: If ``type_main`` is neither ``"SVD"`` nor ``"rand"``.
    """
    if type_main == "SVD":
        # Rows of vh are the right-singular vectors in descending order of
        # singular value. Keeping the first dim_main rows gives the same
        # layout as the "rand" branch; the full, un-transposed v returned
        # before matched neither caller's expected shape.
        _, _, vh = torch.linalg.svd(x, full_matrices=False)
        return vh[:dim_main]

    if type_main == "rand":
        return torch.randn(dim_main, x.shape[1])

    # An unknown type used to fall through and return None.
    raise ValueError(
        "Unknown initialisation type " + repr(type_main) + "; expected 'SVD' or 'rand'"
    )
    
                 
               
               
def update_set_gcode(x,main_parameters,main_code,config,fix):
    """Run one round of pseudo-inverse updates for a single data set.

    The entries of ``main_parameters`` (``alpha``, ``beta``, ``intercept``)
    and ``main_code`` (``code``, ``encode``) are refreshed in the order code,
    alpha, beta, intercept, encode; each step is skipped when its ``fix``
    flag is truthy. Both dicts are modified in place and also returned.
    ``config`` is accepted but not read.
    """
    pinv = torch.linalg.pinv

    code_is_free = not fix["code"]
    if code_is_free:
        centred = x - main_parameters["intercept"]
        left = pinv(main_parameters["alpha"].T)
        right = pinv(main_parameters["beta"].T)
        main_code["code"] = left @ centred @ right

    alpha_is_free = not fix["alpha"]
    if alpha_is_free:
        centred = x - main_parameters["intercept"]
        code_beta = main_code["code"] @ main_parameters["beta"].T
        main_parameters["alpha"] = (centred @ pinv(code_beta)).T

    beta_is_free = not fix["beta"]
    if beta_is_free:
        centred = x - main_parameters["intercept"]
        alpha_code = main_parameters["alpha"].T @ main_code["code"]
        main_parameters["beta"] = (pinv(alpha_code) @ centred).T

    intercept_is_free = not fix["intercept"]
    if intercept_is_free:
        # Column-wise mean of the part of x the current fit leaves over.
        fitted = main_parameters["alpha"].T @ main_code["code"] @ main_parameters["beta"].T
        main_parameters["intercept"] = torch.mean(x - fitted, 0)

    encode_is_free = not fix["encode"]
    if encode_is_free:
        # Uses the intercept as it stands after the step above.
        centred = x - main_parameters["intercept"]
        main_code["encode"] = main_parameters["alpha"] @ centred @ main_parameters["beta"]

    return main_parameters, main_code


In [3]:
import torch
import torchvision
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
import pandas as pd
import numpy as np

In [4]:
# Convert images to tensors, then normalise each of the three channels with
# mean 0.5 and standard deviation 0.5.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
])

# CIFAR-10 training split, downloaded into ./data when missing.
trainset = torchvision.datasets.CIFAR10(
    root='./data',
    train=True,
    download=True,
    transform=transform,
)

# CIFAR-10 test split with the same preprocessing.
testset = torchvision.datasets.CIFAR10(
    root='./data',
    train=False,
    download=True,
    transform=transform,
)


Files already downloaded and verified
Files already downloaded and verified


In [5]:
# Solver settings shared by lmform and gcode.
config = {
    "init": {"alpha": "rand", "beta": "rand"},
    "max_iter": 3,
    "verbose": False,
    "seed": 1,
    "tol": 1,
}

# A single data set, "data_1", owns every parameter and code slot.
join = {
    "complete": {
        "data_list": ["data_1"],
        "alpha": ["data_1"],
        "beta": ["data_1"],
        "code": ["data_1"],
        "encode": ["data_1"],
    },
}

# Nothing is held fixed; the starting-value slots are filled in below.
transfer = {
    "fix": {
        "alpha": False,
        "beta": False,
        "code": False,
        "intercept": False,
        "encode": False,
    },
    "main_code": {"code": {}, "encode": {}},
    "main_parameters": {"alpha": {}, "beta": {}, "intercept": {}},
}

# Give every slot an empty placeholder, keyed by the matching join list.
for i in range(len(join["complete"]["data_list"])):
    for section, field, key_source in (
        ("main_code", "code", "code"),
        ("main_code", "encode", "encode"),
        ("main_parameters", "alpha", "alpha"),
        ("main_parameters", "beta", "beta"),
        ("main_parameters", "intercept", "data_list"),
    ):
        transfer[section][field][join["complete"][key_source][i]] = []


In [6]:
from torch.nn.functional import one_hot

# Mini-batch size used by both loaders.
batch_size = 32

# Both loaders reshuffle on every pass and use one worker process.
trainloader = torch.utils.data.DataLoader(
    trainset,
    batch_size=batch_size,
    shuffle=True,
    num_workers=1,
)

testloader = torch.utils.data.DataLoader(
    testset,
    batch_size=batch_size,
    shuffle=True,
    num_workers=1,
)

# Human-readable class names; presumably ordered by CIFAR-10 label index —
# verify against the dataset before relying on the order.
classes = (
    'plane', 'car', 'bird', 'cat', 'deer',
    'dog', 'frog', 'horse', 'ship', 'truck',
)

In [13]:
import torch.nn as nn
import torch.nn.functional as F


class summary_layer(nn.Module):
    """Linear projection ``x @ beta`` whose ``beta`` comes from a learned map.

    ``beta`` has shape ``(J_dim, common_dim)`` and is produced by applying
    ``beta_func1`` to an identity matrix. With ``init=True`` it is first
    refined by ``lmform`` on the current batch, using the module-level
    ``transfer``, ``config`` and ``join`` dicts, which this layer mutates.
    """

    def __init__(self, J_dim, common_dim):
        super().__init__()

        self.code_func_c = nn.Linear(J_dim,common_dim)
        self.code_func_I = nn.Linear(common_dim,common_dim)

        # Registered as a non-persistent buffer so it follows the module
        # through .to()/.double()/.cuda() without appearing in state_dict().
        self.register_buffer("beta_I", torch.eye(J_dim, J_dim), persistent=False)
        self.beta_func1 = nn.Linear(J_dim,common_dim)

    def forward(self, x, init, gradient):
        """Project ``x`` of shape ``(B, J_dim)`` to ``(B, common_dim)``.

        Args:
            x: 2-D input batch.
            init: When true, refine ``beta`` with ``lmform`` before projecting.
            gradient: Only used when ``init`` is true. If true, ``lmform``
                starts from the layer's own ``beta`` and code; otherwise from
                standard-normal tensors of the same shapes.
        """
        # The unpacking also rejects inputs that are not 2-D.
        B,D = x.shape

        beta = self.beta_func1(self.beta_I)

        if init:
            # The code is only needed to seed lmform, so compute it here
            # rather than on every forward pass.
            code = self.code_func_I(self.code_func_c(x))

            if gradient:
                start_beta, start_code = beta, code
            else:
                start_beta = torch.randn(beta.shape)
                start_code = torch.randn(code.shape)

            transfer["main_parameters"]["beta"]["data_1"] = start_beta
            # code and encode deliberately start from the same tensor.
            transfer["main_code"]["encode"]["data_1"] = start_code
            transfer["main_code"]["code"]["data_1"] = start_code

            transfer["fix"]["beta"] = False
            transfer["fix"]["code"] = False
            transfer["fix"]["encode"] = False

            config["j_dim"] = beta.shape[1]

            lmform_model = lmform(
                data_list = {"data_1": x},
                config = config,
                join = join,
                transfer = transfer
            )

            beta = lmform_model["main_parameters"]["beta"]["data_1"]

        return x@beta

    
    
class concept_layer(nn.Module):
    """Two-sided projection ``alpha @ x[b, c] @ beta`` of every image channel.

    ``alpha`` (``(i_dim, I_dim)``) and ``beta`` (``(J_dim, j_dim)``) are
    produced by linear maps of the batch's flattened second-moment matrix.
    The layer sizes require ``C * H * W == 3 * I_dim**2 == 3 * J_dim**2``,
    i.e. three-channel square inputs. With ``init=True`` both matrices are
    refined by ``gcode`` using the module-level ``transfer``, ``config`` and
    ``join`` dicts, which this layer mutates.
    """

    def __init__(self, J_dim, j_dim, I_dim, i_dim):
        super().__init__()

        self.j_dim = j_dim
        self.i_dim = i_dim

        # Non-persistent buffer: follows .to()/.double() but stays out of
        # state_dict().
        self.register_buffer("code_I", torch.ones(i_dim, j_dim), persistent=False)
        self.code_func_j = nn.Linear(j_dim,j_dim)
        self.code_func_i = nn.Linear(i_dim,i_dim)

        self.beta_func1 = nn.Linear(J_dim*J_dim*3,J_dim)
        self.beta_func2 = nn.Linear(J_dim*J_dim*3,j_dim)

        self.alpha_func1 = nn.Linear(I_dim*I_dim*3,I_dim)
        self.alpha_func2 = nn.Linear(I_dim*I_dim*3,i_dim)

    def forward(self, x, init, gradient):
        """Map ``x`` of shape ``(B, C, H, W)`` to ``(B, C, i_dim, j_dim)``.

        Args:
            x: 4-D input batch.
            init: When true, refine ``alpha``/``beta`` with ``gcode`` first.
            gradient: Only used when ``init`` is true. If true, ``gcode``
                starts from the layer's own ``alpha``, ``beta`` and code;
                otherwise from standard-normal tensors of the same shapes.
        """
        B,C,H,W = x.shape

        x_flat = torch.flatten(x,1)
        x_cov = x_flat.T@x_flat

        # alpha_func2 yields (I_dim, i_dim). It has to be transposed to
        # (i_dim, I_dim) to left-multiply an (H, W) channel and to match the
        # alpha layout gcode works with.
        alpha = self.alpha_func2(self.alpha_func1(x_cov).T).T
        beta = self.beta_func2(self.beta_func1(x_cov).T)

        if init:
            # The code is only needed to seed gcode.
            code = self.code_func_j(self.code_func_i(self.code_I.T).T)

            if gradient:
                start_alpha, start_beta, start_code = alpha, beta, code
            else:
                start_alpha = torch.randn(alpha.shape)
                start_beta = torch.randn(beta.shape)
                start_code = torch.randn(code.shape)

            transfer["main_parameters"]["alpha"]["data_1"] = start_alpha
            transfer["main_parameters"]["beta"]["data_1"] = start_beta
            # code and encode deliberately start from the same tensor.
            transfer["main_code"]["encode"]["data_1"] = start_code
            transfer["main_code"]["code"]["data_1"] = start_code

            transfer["fix"]["alpha"] = False
            transfer["fix"]["beta"] = False
            transfer["fix"]["code"] = False
            transfer["fix"]["encode"] = False

            config["i_dim"] = alpha.shape[0]
            config["j_dim"] = beta.shape[1]

            for iteration in range(3):
                for b in range(B):
                    for c in range(C):
                        # Fit every channel. The fit used to sit outside this
                        # loop, so only the last channel of each sample was
                        # ever passed to gcode.
                        gcode_model = gcode(
                            data_list = {"data_1": x[b,c,:,:]},
                            config = config,
                            join = join,
                            transfer = transfer
                        )

                        # Warm-start the next fit from this result.
                        transfer["main_parameters"]["alpha"]["data_1"] = gcode_model["main_parameters"]["alpha"]["data_1"]
                        transfer["main_parameters"]["beta"]["data_1"] = gcode_model["main_parameters"]["beta"]["data_1"]
                        transfer["main_code"]["encode"]["data_1"] = gcode_model["main_code"]["encode"]["data_1"]
                        transfer["main_code"]["code"]["data_1"] = gcode_model["main_code"]["code"]["data_1"]

            # NOTE: no dropout is applied to the fitted matrices; this layer
            # defines no dropout modules (the previous references to
            # self.dropout6 / self.dropout4 raised AttributeError).
            alpha = transfer["main_parameters"]["alpha"]["data_1"]
            beta = transfer["main_parameters"]["beta"]["data_1"]

        # Matrix products broadcast over the batch and channel axes.
        return alpha@x@beta
    
        
    
class ConceptNet(nn.Module):
    """Two conv + max-pool stages followed by two ``summary_layer`` stages.

    Args:
        num_classes: Width of the final ``summary_layer``, i.e. the number of
            output scores per sample.
    """

    def __init__(self,num_classes):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 12, 5)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(12, 32, 5)
        # 32*5*5 is the flattened conv output for 32x32 inputs
        # (32 -> 28 -> 14 -> 10 -> 5 per side).
        self.summary_layer1 = summary_layer(32*5*5, 500)
        # Size the output from num_classes; it was hard-coded to 10 and the
        # constructor argument was ignored.
        self.summary_layer2 = summary_layer(500, num_classes)

    def forward(self, x, init = True, gradient = False):
        """Return per-class scores for a batch of three-channel images.

        ``init`` and ``gradient`` are passed unchanged to both summary layers.
        """
        x = self.pool(self.conv1(x))
        x = self.pool(self.conv2(x))
        x = self.summary_layer1(torch.flatten(x,1),init,gradient)
        x = self.summary_layer2(x,init,gradient)
        return x
    
    
    

def ConceptNet_simple(num_classes):
    """Factory: build and return a ``ConceptNet`` for ``num_classes`` classes."""
    model = ConceptNet(num_classes)
    return model
    

In [14]:
import torch.optim as optim

# Model with ten output classes.
net = ConceptNet_simple(10)

# Cross-entropy loss, optimised by SGD with momentum.
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(
    net.parameters(),
    lr=0.001,
    momentum=0.9,
)

In [15]:
import time

start = time.time()

# One row per epoch: [epoch, seconds since start, summed loss, test accuracy %].
results = []

for epoch in range(30):  # loop over the dataset multiple times
    count = 0
    running_loss = 0.0

    # ---- training pass ----
    for i, data in enumerate(trainloader):
        count = i
        inputs, labels = data  # data is a list of [inputs, labels]

        optimizer.zero_grad()  # clear gradients from the previous step

        # forward (init=False, gradient=True) + backward + optimize
        outputs = net(inputs, False, True)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        running_loss += loss.item()

    # The reported loss is the sum over all batches of the epoch.
    print(f'[{epoch + 1}, {count + 1:5d}] loss: {running_loss  :.10f}')

    # ---- evaluation pass on the test loader ----
    correct = 0
    total = 0
    with torch.no_grad():
        for data in testloader:
            images, labels = data
            outputs = net(images,False, False)
            # the highest-scoring class is the prediction
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
    print(f'Accuracy of the network on the 10000 test images: {100 * correct / total} %')

    end = time.time()
    print(end - start)
    results.append([epoch, end - start, running_loss, 100 * correct / total])
    running_loss = 0.0

print('Finished Training')

end = time.time()
print(end - start)

[1,  1563] loss: 2837.4730811119
Accuracy of the network on the 10000 test images: 43.03 %
156.8312439918518
[2,  1563] loss: 2357.3210983872
Accuracy of the network on the 10000 test images: 49.86 %
337.47950983047485
[3,  1563] loss: 2182.3548301458
Accuracy of the network on the 10000 test images: 53.04 %
510.62689423561096
[4,  1563] loss: 2040.0143188238
Accuracy of the network on the 10000 test images: 53.82 %
689.2952921390533
[5,  1563] loss: 1939.0150303841
Accuracy of the network on the 10000 test images: 56.0 %
866.0457944869995
[6,  1563] loss: 1864.8996770978
Accuracy of the network on the 10000 test images: 58.04 %
1047.4660449028015
[7,  1563] loss: 1811.8484335840
Accuracy of the network on the 10000 test images: 59.36 %
1204.4575395584106
[8,  1563] loss: 1757.6548793912
Accuracy of the network on the 10000 test images: 60.26 %
1346.8315725326538
[9,  1563] loss: 1717.3020564914
Accuracy of the network on the 10000 test images: 61.68 %
1484.166115283966
[10,  1563] los