In [2]:
# Colab-specific working directory: the relative paths used later (./data for the
# datasets, output{epoch}.png for saved reconstruction grids) resolve under it.
%cd /content/auto_encoder

/content/auto_encoder


In [17]:
import os
import torch
import torchvision
import torch.nn as nn
import torchvision.transforms as transforms
import torch.optim as optim
import matplotlib.pyplot as plt
import torch.nn.functional as F

from torchvision import datasets
from torch.utils.data import DataLoader
from torchvision.utils import save_image
from tqdm.notebook import tqdm

In [18]:
# Learning parameters
epochs = 100
batch_size = 125
lr = 0.0001
seed = 42
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Seed torch's RNG before the model is built (later cell) so weight
# initialization -- and any DataLoader shuffling -- is reproducible on
# Restart & Run All.
torch.manual_seed(seed)

# Image transformations applied to both FashionMNIST splits
transform = transforms.Compose([
    transforms.ToTensor(),
])

In [19]:
train_data = datasets.FashionMNIST(
    root='./data',
    train=True,
    download=True,
    transform=transform
)

val_data = datasets.FashionMNIST(
    root='./data',
    train=False,
    download=True,
    transform=transform
)

# Reshuffle training samples every epoch; without this the optimizer sees the
# exact same batches in the same order on every pass.
train_loader = DataLoader(
    train_data,
    batch_size=batch_size,
    shuffle=True,
)

# Validation order does not affect the loss; keep it fixed so the
# reconstruction grid saved from the last batch shows the same images each epoch.
val_loader = DataLoader(
    val_data,
    batch_size=batch_size,
    shuffle=False,
)

In [20]:
class Autoencoder(nn.Module):
    """Fully connected autoencoder: in -> hidden -> latent -> hidden -> in.

    The defaults (784 -> 512 -> 32) reproduce the original architecture, so
    ``Autoencoder()`` builds the same layers and prints the same summary.

    Args:
        in_features: size of each flattened input sample (and of its reconstruction).
        hidden_features: width of the hidden layer on both encoder and decoder side.
        latent_features: size of the bottleneck code.
    """

    def __init__(self, in_features=784, hidden_features=512, latent_features=32):
        super().__init__()

        # encoder
        self.enc1 = nn.Linear(in_features=in_features, out_features=hidden_features)
        self.enc2 = nn.Linear(in_features=hidden_features, out_features=latent_features)

        # decoder
        self.dec1 = nn.Linear(in_features=latent_features, out_features=hidden_features)
        self.dec2 = nn.Linear(in_features=hidden_features, out_features=in_features)

    def forward(self, x):
        """Reconstruct ``x``.

        Args:
            x: tensor whose last dimension is ``in_features`` (e.g. ``(batch, 784)``).

        Returns:
            Tensor with the same shape as ``x``; values lie in [0, 1] because of
            the final sigmoid.
        """
        # encoding -- ReLU is also applied to the latent code, so it is non-negative
        x = F.relu(self.enc1(x))
        x = F.relu(self.enc2(x))

        # decoding
        x = F.relu(self.dec1(x))
        return torch.sigmoid(self.dec2(x))

# Build the autoencoder, place it on the selected device, and show its layers.
model = Autoencoder()
model = model.to(device)
print(model)


Autoencoder(
  (enc1): Linear(in_features=784, out_features=512, bias=True)
  (enc2): Linear(in_features=512, out_features=32, bias=True)
  (dec1): Linear(in_features=32, out_features=512, bias=True)
  (dec2): Linear(in_features=512, out_features=784, bias=True)
)


In [21]:
# Reconstruction loss: mean squared error between reconstruction and input.
criterion = nn.MSELoss()
# Adam over all autoencoder parameters at the configured learning rate.
optimizer = optim.Adam(params=model.parameters(), lr=lr)

In [22]:
def validate(model, dataloader, epoch):
    """Evaluate ``model`` on ``dataloader`` and save/show a reconstruction grid.

    Runs in eval mode with gradient tracking disabled. For the final batch,
    up to 8 inputs and their reconstructions are written to
    ``output{epoch}.png`` (inputs on the first row, reconstructions on the
    second) and displayed inline.

    Relies on the notebook-level ``criterion`` and ``device``.

    Args:
        model: autoencoder taking flattened ``(batch, 784)`` tensors and
            returning reconstructions of the same shape.
        dataloader: sized iterable of ``(images, labels)`` batches; labels are
            ignored. Each sample must flatten to 28*28 values.
        epoch: value used in the output file name.

    Returns:
        Mean of the per-batch losses (the same normalization ``fit`` uses, so
        train and validation curves are comparable), or ``nan`` if the loader
        yields no batches.
    """
    model.eval()
    num_batches = len(dataloader)
    if num_batches == 0:
        return float('nan')

    max_images_to_show = 8
    running_loss = 0.0
    with torch.no_grad():
        for i, (data, _) in tqdm(enumerate(dataloader), total=num_batches):
            data = data.to(device)
            data = data.view(data.size(0), -1)
            reconstruction = model(data)
            loss = criterion(reconstruction, data)
            running_loss += loss.item()

            # len(dataloader) counts a trailing partial batch, so this is the
            # real last batch even when the dataset size is not a multiple of
            # batch_size (floor-dividing the dataset size would pick the
            # second-to-last batch instead).
            if i == num_batches - 1:
                n_show = min(max_images_to_show, data.size(0))
                both = torch.cat((
                    data.view(data.size(0), 1, 28, 28)[:n_show],
                    reconstruction.view(data.size(0), 1, 28, 28)[:n_show],
                ))
                out_path = f"output{epoch}.png"
                # nrow=n_show keeps inputs and reconstructions on separate rows
                # even when the last batch has fewer than 8 samples.
                save_image(both.cpu(), out_path, nrow=n_show)
                output_img = plt.imread(out_path)
                plt.imshow(output_img)
                plt.axis('off')
                plt.show()
                plt.close()  # don't accumulate open figures across epochs

    # Each loss.item() is already a per-batch mean (assuming a mean-reducing
    # criterion such as the default nn.MSELoss()), so average over batches,
    # not over samples.
    return running_loss / num_batches

In [23]:
def fit(model, dataloader):
    """Train ``model`` for one pass over ``dataloader``.

    Uses the notebook-level ``optimizer``, ``criterion`` and ``device``.
    Each batch of images is flattened to ``(batch, -1)`` and the model is
    optimized to reproduce it.

    Args:
        model: autoencoder mapping a flattened batch to a same-shaped reconstruction.
        dataloader: iterable of ``(images, labels)`` batches; labels are ignored.

    Returns:
        Mean of the per-batch losses for the epoch. Raises ``ZeroDivisionError``
        if the loader yields no batches.
    """
    print('Training')
    model.train()
    batch_losses = []
    for images, _labels in dataloader:
        inputs = images.to(device).view(images.size(0), -1)
        optimizer.zero_grad()
        batch_loss = criterion(model(inputs), inputs)
        batch_loss.backward()
        optimizer.step()
        batch_losses.append(batch_loss.item())
    return sum(batch_losses) / len(batch_losses)

In [24]:
# Per-epoch loss history
train_loss, val_loss = [], []

for epoch in range(epochs):
    print(f"Epoch {epoch+1} of {epochs}")
    train_epoch_loss = fit(model, train_loader)
    # validate also saves output{epoch}.png using the 0-based epoch index
    val_epoch_loss = validate(model, val_loader, epoch)
    print(f"Train Loss: {train_epoch_loss:.6f}\nVal Loss: {val_epoch_loss:.6f}")
    train_loss.append(train_epoch_loss)
    val_loss.append(val_epoch_loss)


Output hidden; open in https://colab.research.google.com to view.