In [9]:
import torch
from torchvision import datasets
from torchvision import transforms
import matplotlib.pyplot as plt

In [10]:
import wandb

wandb.login()

run = wandb.init(
    # Set the project where this run will be logged
    project="my-awesome-project",
    # Track hyperparameters and run metadata
    config={
        "learning_rate": 0.001,
        "epochs": 10,
    },
)



0,1
loss,█▂▂▂▂▁▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁

0,1
loss,0.02026


In [11]:
import optuna

  from .autonotebook import tqdm as notebook_tqdm


In [12]:
# Transforms images to a PyTorch Tensor
tensor_transform = transforms.ToTensor()

# Download the MNIST Dataset
dataset = datasets.FashionMNIST(root="./data", train=True, download=True, transform=tensor_transform)

# DataLoader is used to load the dataset for training
loader = torch.utils.data.DataLoader(dataset=dataset, batch_size=32, shuffle=True)

In [13]:
# Creating a PyTorch class
# 28*28 ==> 8 ==> 28*28
class AE(torch.nn.Module):
    def __init__(self):
        super().__init__()

        # Building an linear encoder with Linear
        # layer followed by Relu activation function
        # 784 ==> 8
        self.encoder = torch.nn.Sequential(
            torch.nn.Linear(28 * 28, 128),
            torch.nn.ReLU(),
            torch.nn.Linear(128, 64),
            torch.nn.ReLU(),
            torch.nn.Linear(64, 32),
            torch.nn.ReLU(),
            torch.nn.Linear(32, 16),
            torch.nn.ReLU(),
            torch.nn.Linear(16, 8),
        )

        # Building an linear decoder with Linear
        # layer followed by Relu activation function
        # The Sigmoid activation function
        # outputs the value between 0 and 1
        # 8 ==> 784
        self.decoder = torch.nn.Sequential(
            torch.nn.Linear(8, 16),
            torch.nn.ReLU(),
            torch.nn.Linear(16, 32),
            torch.nn.ReLU(),
            torch.nn.Linear(32, 64),
            torch.nn.ReLU(),
            torch.nn.Linear(64, 128),
            torch.nn.ReLU(),
            torch.nn.Linear(128, 28 * 28),
            torch.nn.Sigmoid(),
        )

    def forward(self, x):
        encoded = self.encoder(x)
        decoded = self.decoder(encoded)
        return decoded

In [14]:
# Model Initialization
model = AE()

# Validation using MSE Loss function
loss_function = torch.nn.MSELoss()

# Using an Adam Optimizer with lr = 0.1
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3, weight_decay=1e-8)

In [15]:
# epochs = 10
# outputs = []
# losses = []
# for epoch in range(epochs):
#     print("Epoch: ", epoch)
#     for image, _ in loader:

#         # Reshaping the image to (-1, 784)
#         image = image.reshape(-1, 28 * 28)

#         # Output of Autoencoder
#         reconstructed = model(image)

#         # Calculating the loss function
#         loss = loss_function(reconstructed, image)

#         # The gradients are set to zero,
#         # the gradient is computed and stored.
#         # .step() performs parameter update
#         optimizer.zero_grad()
#         loss.backward()
#         optimizer.step()

#         # Storing the losses in a list for plotting
#         losses.append(loss.item())

#         wandb.log({"loss": loss})


#     outputs.append((epochs, image, reconstructed))
#     print("Loss", loss.item())

# optuna
def train_model(trial):
    epochs = trial.suggest_int("epochs", 1, 10)
    outputs = []
    losses = []
    
    for epoch in range(epochs):
        print("Epoch: ", epoch)
        for image, _ in loader:

            # Reshaping the image to (-1, 784)
            image = image.reshape(-1, 28 * 28)

            # Output of Autoencoder
            reconstructed = model(image)

            # Calculating the loss function
            loss = loss_function(reconstructed, image)

            # The gradients are set to zero,
            # the gradient is computed and stored.
            # .step() performs parameter update
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            # Storing the losses in a list for plotting
            losses.append(loss.item())
            wandb.log({"loss": loss})

        outputs.append((epoch, image, reconstructed))
        print("Loss", loss.item())
    
    return losses[-1]


study = optuna.create_study(direction="minimize")
study.optimize(train_model, n_trials=3)

print("Best trial:")
trial = study.best_trial
print(f"  Value: {trial.value}")
print("  Params:")
for key, value in trial.params.items():
    print(f"    {key}: {value}")

[I 2024-11-03 17:33:50,382] A new study created in memory with name: no-name-3af73de4-c52c-47ea-80f3-328d4c6fb259


Epoch:  0
Loss 0.03336319699883461
Epoch:  1
Loss 0.02909882925450802
Epoch:  2
Loss 0.021884839981794357
Epoch:  3
Loss 0.02209015190601349
Epoch:  4
Loss 0.022297341376543045
Epoch:  5
Loss 0.021847238764166832
Epoch:  6


[I 2024-11-03 17:35:12,001] Trial 0 finished with value: 0.022891433909535408 and parameters: {'epochs': 7}. Best is trial 0 with value: 0.022891433909535408.


Loss 0.022891433909535408
Epoch:  0
Loss 0.022220324724912643
Epoch:  1
Loss 0.021570995450019836
Epoch:  2
Loss 0.016582416370511055
Epoch:  3
Loss 0.019029689952731133
Epoch:  4
Loss 0.020977642387151718
Epoch:  5
Loss 0.020268399268388748
Epoch:  6
Loss 0.019667472690343857
Epoch:  7
Loss 0.015592856332659721
Epoch:  8
Loss 0.01601189188659191
Epoch:  9


[I 2024-11-03 17:37:11,564] Trial 1 finished with value: 0.01766280084848404 and parameters: {'epochs': 10}. Best is trial 1 with value: 0.01766280084848404.


Loss 0.01766280084848404
Epoch:  0


[I 2024-11-03 17:37:23,602] Trial 2 finished with value: 0.01946624554693699 and parameters: {'epochs': 1}. Best is trial 1 with value: 0.01766280084848404.


Loss 0.01946624554693699
Best trial:
  Value: 0.01766280084848404
  Params:
    epochs: 10


In [16]:
# # Defining the Plot Style
# plt.style.use("fivethirtyeight")
# plt.xlabel("Iterations")
# plt.ylabel("Loss")

# # Plotting the last 100 values
# plt.plot(losses[-100:])

In [17]:
# example = image[:5].cpu().detach().numpy()
# reconstructed_example = reconstructed[:5].cpu().detach().numpy()

# for i in range(5):
#     plt.subplot(2, 5, i + 1)
#     plt.imshow(example[i].reshape(28, 28), cmap="gray")
#     plt.axis("off")

#     plt.subplot(2, 5, i + 6)
#     plt.imshow(reconstructed_example[i].reshape(28, 28), cmap="gray")
#     plt.axis("off")

# plt.show()