In [38]:
import torch
import torch.nn as nn
import torch.utils.data as Data
import torchvision
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
from matplotlib import cm
from torch.autograd import Variable
import numpy as np
from torchvision import transforms
import os
import torch.optim as optim
import torch.nn.functional as F

In [39]:
config = {
    "num_epochs":50,
    "lr":1e-3,
    "regular_constant":1e-5,
    "batch_size" : 128,
    "train_transform":transforms.Compose([
        transforms.ToTensor(),
        torchvision.transforms.Normalize((0.1307,), (0.3081,)),
    ]),
    "test_transform":transforms.Compose([
        transforms.ToTensor(),
        torchvision.transforms.Normalize((0.1307,), (0.3081,))
    ]), 
}
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [40]:
class AutoEncoder(nn.Module):
    def __init__(self):
        super(AutoEncoder,self).__init__()
        self.encoder = nn.Sequential( 
            nn.Conv2d(in_channels = 1, out_channels = 28, kernel_size=4, stride=1, padding=0),
            nn.BatchNorm2d(28),
            nn.ReLU(),
            nn.Conv2d(in_channels = 28, out_channels = 32, kernel_size=4, stride=1, padding=0),
            nn.BatchNorm2d(32),
            nn.ReLU(),
            )

        self.decoder = nn.Sequential( 
            nn.ConvTranspose2d(in_channels = 32, out_channels = 28, kernel_size=4, stride=1, padding=0),
            nn.BatchNorm2d(28),
            nn.ReLU(),
            nn.ConvTranspose2d(in_channels = 28, out_channels = 1, kernel_size=4, stride=1, padding=0),
            nn.BatchNorm2d(1),
            nn.ReLU(),
            )

    def forward(self,x):
        # print("the original is: ",x.shape)
        encoder = self.encoder(x)
        decoder = self.decoder(encoder)
        # print("after decorder ",decoder.shape)
        return decoder

In [41]:
# # !pip install torchsummary

from torchsummary import summary
summary(AutoEncoder().to(device), input_size=(1, 28, 28))

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1           [-1, 28, 25, 25]             476
       BatchNorm2d-2           [-1, 28, 25, 25]              56
              ReLU-3           [-1, 28, 25, 25]               0
            Conv2d-4           [-1, 32, 22, 22]          14,368
       BatchNorm2d-5           [-1, 32, 22, 22]              64
              ReLU-6           [-1, 32, 22, 22]               0
   ConvTranspose2d-7           [-1, 28, 25, 25]          14,364
       BatchNorm2d-8           [-1, 28, 25, 25]              56
              ReLU-9           [-1, 28, 25, 25]               0
  ConvTranspose2d-10            [-1, 1, 28, 28]             449
      BatchNorm2d-11            [-1, 1, 28, 28]               2
             ReLU-12            [-1, 1, 28, 28]               0
Total params: 29,835
Trainable params: 29,835
Non-trainable params: 0
---------------------------------

In [42]:
def train(train_dataloader, validate_dataloader, device, config, path):
    model = AutoEncoder().to(device)
    optimizer = optim.Adam(
        model.parameters(),
        lr=config["lr"],
        betas=(0.9, 0.999),
        weight_decay=config["regular_constant"],
    )
    train_loss_value = []
    validate_loss_value = []
    train_accuracy = []
    validate_accuracy = []
    current_epoch = []
    # acc = 0.0
    low_loss = torch.tensor(float('inf')).cuda()
    for epoch in range(config["num_epochs"]):
        model.train()
        train_loss = 0
        correct = 0
        total = 0
        current_epoch.append(epoch + 1)
        print("####### Training Processing #######")
        print("in epoch: ", epoch + 1)
        if epoch > 20:
            optimizer = optim.Adam(
                model.parameters(),
                lr=config["lr"] * 0.5,
                betas=(0.9, 0.999),
                weight_decay=config["regular_constant"],
            )
        if epoch > 30:
            optimizer = optim.Adam(
                model.parameters(),
                lr=config["lr"] * 0.2,
                betas=(0.9, 0.999),
                weight_decay=config["regular_constant"],
            )

        if epoch > 35:
            optimizer = optim.Adam(
                model.parameters(),
                lr=config["lr"] * 0.1,
                betas=(0.9, 0.999),
                weight_decay=config["regular_constant"],
            )

        for batch_idx, (inputs, targets) in enumerate(train_dataloader):
            inputs, targets = inputs.to(device), targets.to(device)
            # print(inputs.shape,targets.shape)
            # print("###########")
            optimizer.zero_grad()
            outputs = model(inputs)
            # print(outputs.shape)
            loss = loss_function(outputs, inputs)
            loss.backward()
            optimizer.step()

            train_loss += loss.item()
            # _, predicted = outputs.max(1)
            total += targets.size(0)
            # correct += predicted.eq(targets).sum().item()
        train_loss /= len(train_dataloader.dataset)
        # print(
            # "Training set: Avg. loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n".format(
                # train_loss, correct, total, 100.0 * correct / total
            # )
        # )
        print("Training set: Avg. loss: {:.6f}".format(train_loss))
        train_loss_value.append(train_loss)
        # train_accuracy.append(100.0 * (correct / total))

        # Validation step
        model.eval()
        validation_loss = 0
        correct = 0
        with torch.no_grad():
            for data, target in validate_dataloader:
                data, target = data.to(device), target.to(device)
                output = model(data) 
                # pred = output.data.max(1, keepdim=True)[1]
                # correct += pred.eq(target.data.view_as(pred)).sum().item()
                validation_loss += loss_function(output, data).item()
            validation_loss /= len(validate_dataloader.dataset)
            # print(
            #     "\nValidation set: Avg. loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n".format(
            #         validation_loss,
            #         correct,
            #         len(validate_dataloader.dataset),
            #         100.0 * correct / len(validate_dataloader.dataset),
            #     )
            # )
            print("\nValidation set: Avg. loss: {:.6f}".format(validation_loss))
        # current_acc = 100.0 * correct / len(validate_dataloader.dataset)

        validate_loss_value.append(validation_loss)
        # validate_accuracy.append(current_acc)
        # if current_acc > acc:
        #     acc = current_acc
        if validation_loss < low_loss:
            low_loss = validation_loss
        
            torch.save(model.state_dict(), os.path.join(path, "ckpt.pth"))
            print("model save at checkpoint")

    plt.plot(current_epoch, train_loss_value, "b", label="Training Loss")
    plt.plot(current_epoch, validate_loss_value, "r", label="Validation Loss")
    plt.title("Loss v.s. Epochs")
    plt.legend()
    plt.savefig(os.path.join(path, "loss_curve.jpg"))
    plt.figure()

    # plt.plot(current_epoch, train_accuracy, "b", label="Training Accuracy")
    # plt.plot(current_epoch, validate_accuracy, "r", label="Validation Accuracy")
    # plt.title("Accuracy v.s. Epochs")
    # plt.legend()
    # plt.savefig(os.path.join(path, "accuracy.jpg"))
    # plt.show()
    return model

In [47]:
def test(test_dataloader, model, device):
    test_predictions = []
    true_labels = []
    model.eval()
    test_loss = 0
    correct = 0

    with torch.no_grad():
        for data, target in test_dataloader:
            data,target = data.to(device), target.to(device)
            output = model(data)
            loss = loss_function(output, data)
            test_loss += loss.item()
            # pred = output.data.max(1, keepdim=True)[1]
            # test_predictions.append(pred[0])
            # true_labels.append(target.data.view_as(pred)[0])
            # correct += pred.eq(target.data.view_as(pred)).sum()
        test_loss /= len(test_dataloader.dataset)
        # print(
        #     "Test set: Avg. loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n".format(
        #         test_loss,
        #         correct,
        #         len(test_dataloader.dataset),
        #         100.0 * correct.item() / len(test_dataloader.dataset),
        #     )
        # )
        print("\nTest set: Avg. loss: {:.6f}".format(test_loss))
        return test_predictions, true_labels

In [44]:
train_data = torchvision.datasets.MNIST(
    root='./data/', 
    train=True,
    transform = config["train_transform"],  
    download=True,  
)

test_data = torchvision.datasets.MNIST(
    root='./data/', 
    train=False, 
    transform = config["test_transform"],  
    download=True,  
)
training_set, validation_set = torch.utils.data.random_split(train_data, [int(len(train_data)*0.8), int(len(train_data)*0.2)])

train_loader = Data.DataLoader(
    training_set,
    batch_size = config["batch_size"],
    shuffle = True,
    ) 

vali_loader = Data.DataLoader(
    validation_set,
    batch_size = config["batch_size"],
    shuffle = True,
    ) 

test_loader  = Data.DataLoader(
    test_data,
    batch_size = config["batch_size"],
    shuffle = False,
)

In [48]:
loss_function = nn.MSELoss()
if os.path.exists("ckpt.pth"):
    checkpoint = torch.load("ckpt.pth", map_location=device)
    model = AutoEncoder().to(device)
    model.load_state_dict(checkpoint)
else:
    model = train(train_loader, vali_loader, device, config, os.getcwd())
# model = train(train_loader,vali_loader,device,config,path = os.getcwd())
test(test_loader,model,device)


Test set: Avg. loss: 0.001159


([], [])