In [71]:
import torch
import torch.nn as nn
import torchvision
import torchvision.datasets as datasets
import torchvision.transforms as transforms
import torch.nn.init as weight_init
import matplotlib.pyplot as plt
import pdb
import torch.nn.functional as F


#parameters
batch_size = 128

preprocess = transforms.Compose([
                           transforms.ToTensor(),
                           transforms.Normalize((0.1307,), (0.3081,))
                       ])

#Loading the train set file
dataset = datasets.MNIST(root='./data',
                            transform=preprocess,  
                            download=True)

loader = torch.utils.data.DataLoader(dataset=dataset, batch_size=batch_size, shuffle=True)

In [72]:
class AE(nn.Module):
    def __init__(self):
        super().__init__()
        self.encoder = nn.Sequential(
            nn.Linear(28*28, 256),
            nn.ReLU(),
            nn.Linear(256,64),
            nn.ReLU(),
            nn.Linear(64,20),
        )
        self.decoder = nn.Sequential(
            nn.Linear(20, 64),
            nn.ReLU(),
            nn.Linear(64, 256),
            nn.ReLU(),
            nn.Linear(256, 28*28),
            nn.Tanh()
        )
    
    def forward(self,x):
        h = self.encoder(x)
        xr = self.decoder(h)
        return xr,h


In [73]:
use_cuda = torch.cuda.is_available()
device = torch.device("cuda:0" if use_cuda else "cpu")
criterion = nn.MSELoss()
learning_rate = 1e-2
weight_decay = 1e-5
net = AE()
net = net.to(device)
optimizer = torch.optim.RMSprop(net.parameters(), lr=learning_rate, weight_decay=weight_decay)

In [74]:
#Training
def train(num_epochs = 50):
    epochLoss = []
    for epoch in range(num_epochs):
        total_loss, cntr = 0, 0

        for i,(images,_) in enumerate(loader):

            images = images.view(-1, 28*28)
            images = images.to(device)

            # Initialize gradients to 0
            optimizer.zero_grad()

            # Forward pass (this calls the "forward" function within Net)
            outputs, _ = net(images)

            # Find the loss
            loss = criterion(outputs, images)

            # Find the gradients of all weights using the loss
            loss.backward()

            # Update the weights using the optimizer and scheduler
            optimizer.step()

            total_loss += loss.item()
            cntr += 1

    #     scheduler.step(total_loss/cntr)
        print ('Epoch [%d/%d], Loss: %.4f' 
                       %(epoch+1, num_epochs, total_loss/cntr))
        epochLoss.append(total_loss/cntr)
    return epochLoss

In [76]:
train()

Epoch [1/50], Loss: 1.4745
Epoch [2/50], Loss: 1.3888
Epoch [3/50], Loss: 1.5239
Epoch [4/50], Loss: 1.3672
Epoch [5/50], Loss: 1.3687
Epoch [6/50], Loss: 1.6003
Epoch [7/50], Loss: 1.4203
Epoch [8/50], Loss: 1.4933
Epoch [9/50], Loss: 1.5200
Epoch [10/50], Loss: 1.4593
Epoch [11/50], Loss: 1.2394
Epoch [12/50], Loss: 0.8446
Epoch [13/50], Loss: 1.0230
Epoch [14/50], Loss: 1.4936
Epoch [15/50], Loss: 1.3484
Epoch [16/50], Loss: 1.3810
Epoch [17/50], Loss: 1.4634
Epoch [18/50], Loss: 1.4559
Epoch [19/50], Loss: 1.4815
Epoch [20/50], Loss: 1.4375
Epoch [21/50], Loss: 1.2742
Epoch [22/50], Loss: 1.1295
Epoch [23/50], Loss: 0.7118
Epoch [24/50], Loss: 0.7170
Epoch [25/50], Loss: 0.7114
Epoch [26/50], Loss: 0.8895
Epoch [27/50], Loss: 0.7115
Epoch [28/50], Loss: 0.7917
Epoch [29/50], Loss: 0.7112
Epoch [30/50], Loss: 0.7295
Epoch [31/50], Loss: 0.7112
Epoch [32/50], Loss: 0.7123
Epoch [33/50], Loss: 0.7197
Epoch [34/50], Loss: 0.7112
Epoch [35/50], Loss: 0.7113
Epoch [36/50], Loss: 0.7323
E

[1.4745078707046346,
 1.3887574123675381,
 1.5239177581343823,
 1.3671875106754587,
 1.368736672757277,
 1.6003218147037888,
 1.42033065204173,
 1.493271048389264,
 1.5199946632771604,
 1.4592721210613941,
 1.2394373127138183,
 0.8445918735410614,
 1.0229771491815287,
 1.4936184916160762,
 1.3484401438536167,
 1.3809959954544426,
 1.4633906899230567,
 1.4559181420279463,
 1.4815368332079988,
 1.437520218937636,
 1.2742042259366781,
 1.1295095768564545,
 0.7117902983480425,
 0.7170050585193675,
 0.7113781751854333,
 0.8894749297770356,
 0.7115073354005306,
 0.7916703736349973,
 0.7111923535749602,
 0.7294961988036313,
 0.7111657167802742,
 0.7123258491314804,
 0.7197357710998958,
 0.7111622841119258,
 0.7112965727411608,
 0.7322640399943029,
 0.7112574574789767,
 0.711305244009632,
 1.1502524912992775,
 0.7753150047523889,
 0.5529984953179796,
 0.5254777169812208,
 0.5059298743952566,
 0.49704156863664006,
 0.4916747366187415,
 0.4884732321762581,
 0.48537342571246345,
 0.48444446762487

In [91]:
class FullNet(nn.Module):
    def __init__(self):
        super(FullNet, self).__init__()
        self.fc1 = nn.Linear(784, 320)
        self.fc2 = nn.Linear(320, 50)
        self.fc3 = nn.Linear(50, 10)
    
    def forward(self,x):
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return F.softmax(x,dim=1)

class RedNet(nn.Module):
    def __init__(self):
        super(RedNet, self).__init__()
        self.fc1 = nn.Linear(20, 15)
        self.fc2 = nn.Linear(15, 12)
        self.fc3 = nn.Linear(12, 10)
    
    def forward(self,x):
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return F.softmax(x,dim=1)

In [92]:
fnet = FullNet()
fnet.to(device)
rnet = RedNet()
rnet.to(device)

optimizer_fnet = torch.optim.SGD(fnet.parameters(), lr=learning_rate)
optimizer_rnet = torch.optim.SGD(rnet.parameters(), lr=learning_rate)
criterion_fnet = nn.CrossEntropyLoss()
criterion_rnet = nn.CrossEntropyLoss()

In [99]:
def train_full(num_epochs = 50):
    epochLoss = []
    for epoch in range(num_epochs):
        total_loss, cntr = 0, 0

        for i,(images,labels) in enumerate(loader):

            images = images.view(-1, 28*28)
            images,labels = images.to(device),labels.to(device)

            # Initialize gradients to 0
            optimizer_fnet.zero_grad()

            # Forward pass (this calls the "forward" function within Net)
            full, red = net(images)
            out_fnet = fnet(full)
            
            # Find the loss
            loss_full = criterion_fnet(out_fnet,labels)
            
            # Find the gradients of all weights using the loss
            loss_full.backward()
            
            # Update the weights using the optimizer and scheduler
            optimizer_fnet.step()
            
            total_loss += loss_full.item()
            cntr += 1

    #     scheduler.step(total_loss/cntr)
        print ('Epoch [%d/%d], Loss: %.4f' 
                       %(epoch+1, num_epochs, total_loss/cntr))
        epochLoss.append(total_loss/cntr)
    return epochLoss

In [100]:
train_full()

Epoch [1/50], Loss: 2.2808
Epoch [2/50], Loss: 2.2158
Epoch [3/50], Loss: 2.0771
Epoch [4/50], Loss: 1.9404
Epoch [5/50], Loss: 1.8302
Epoch [6/50], Loss: 1.7381
Epoch [7/50], Loss: 1.6914
Epoch [8/50], Loss: 1.6686
Epoch [9/50], Loss: 1.6548
Epoch [10/50], Loss: 1.6452
Epoch [11/50], Loss: 1.6379
Epoch [12/50], Loss: 1.6322
Epoch [13/50], Loss: 1.6275
Epoch [14/50], Loss: 1.6237
Epoch [15/50], Loss: 1.6204
Epoch [16/50], Loss: 1.6177
Epoch [17/50], Loss: 1.6153
Epoch [18/50], Loss: 1.6130
Epoch [19/50], Loss: 1.6111
Epoch [20/50], Loss: 1.6094
Epoch [21/50], Loss: 1.6078
Epoch [22/50], Loss: 1.6063
Epoch [23/50], Loss: 1.6049
Epoch [24/50], Loss: 1.6037
Epoch [25/50], Loss: 1.6024
Epoch [26/50], Loss: 1.6014
Epoch [27/50], Loss: 1.6004
Epoch [28/50], Loss: 1.5995
Epoch [29/50], Loss: 1.5985
Epoch [30/50], Loss: 1.5976
Epoch [31/50], Loss: 1.5969
Epoch [32/50], Loss: 1.5960
Epoch [33/50], Loss: 1.5953
Epoch [34/50], Loss: 1.5946
Epoch [35/50], Loss: 1.5940
Epoch [36/50], Loss: 1.5934
E

[2.2808221931904873,
 2.215758320110947,
 2.0771129243155277,
 1.9404132132337037,
 1.8302344304920515,
 1.7381436903593637,
 1.6914477020438545,
 1.6685973433797547,
 1.6548013003396074,
 1.6451801178551941,
 1.6379338421547083,
 1.6322340004479707,
 1.6275464250588976,
 1.6236702658728497,
 1.6204166676698208,
 1.6176703583711245,
 1.6152534988134908,
 1.613047860832865,
 1.6110934224972593,
 1.609395239398932,
 1.6077661178767808,
 1.606256331716265,
 1.6049362278696317,
 1.6036807565546747,
 1.6024412613179384,
 1.601430276563681,
 1.600378854442507,
 1.5994562884129442,
 1.5985317591156787,
 1.5976496934890747,
 1.596874540548589,
 1.5960401759218814,
 1.5952564798184294,
 1.5945969606513408,
 1.5940281124765685,
 1.5933633519134034,
 1.5928026394549208,
 1.5921963468543503,
 1.5916354897688192,
 1.591051027464714,
 1.590519185259398,
 1.5900276993383478,
 1.5895246585040712,
 1.5890038610775588,
 1.5885506827694011,
 1.5881324993776107,
 1.5876412925435537,
 1.5872652421373803,
 

In [103]:
def train_red(num_epochs = 50):
    epochLoss = []
    for epoch in range(num_epochs):
        total_loss, cntr = 0, 0

        for i,(images,labels) in enumerate(loader):

            images = images.view(-1, 28*28)
            images,labels = images.to(device),labels.to(device)

            # Initialize gradients to 0
            optimizer_rnet.zero_grad()

            # Forward pass (this calls the "forward" function within Net)
            full, red = net(images)
            out_rnet = rnet(red)
            
            # Find the loss
            loss_red = criterion_rnet(out_rnet,labels)
            
            # Find the gradients of all weights using the loss
            loss_red.backward()
            
            # Update the weights using the optimizer and scheduler
            optimizer_rnet.step()
            
            total_loss += loss_red.item()
            cntr += 1

    #     scheduler.step(total_loss/cntr)
        print ('Epoch [%d/%d], Loss: %.4f' 
                       %(epoch+1, num_epochs, total_loss/cntr))
        epochLoss.append(total_loss/cntr)
    return epochLoss

In [None]:
train_red()

Epoch [1/50], Loss: 2.2855
Epoch [2/50], Loss: 2.1918
Epoch [3/50], Loss: 2.1208
Epoch [4/50], Loss: 2.0744
Epoch [5/50], Loss: 2.0460
Epoch [6/50], Loss: 1.9941
Epoch [7/50], Loss: 1.9579
Epoch [8/50], Loss: 1.9338
