**Lab-2**

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt

# ====== Data Preparation ======
# ToTensor is the only per-sample preprocessing step applied here.
transform = transforms.ToTensor()

# MNIST training split, downloaded into ./data when it is not already present.
train_set = datasets.MNIST(
    root='./data',
    train=True,
    download=True,
    transform=transform,
)

# Serve the training split in reshuffled mini-batches of 128 samples.
train_loader = DataLoader(
    dataset=train_set,
    batch_size=128,
    shuffle=True,
)

# ====== Add Noise Function ======
def add_noise(imgs, noise_factor=0.5):
    """Return a corrupted copy of ``imgs`` with additive Gaussian noise.

    Standard-normal noise with the same shape as ``imgs`` is scaled by
    ``noise_factor`` and added to the input; the sum is then limited to the
    closed interval [0, 1]. The input tensor itself is not modified.

    Args:
        imgs: Tensor of image values to corrupt.
        noise_factor: Multiplier applied to the sampled noise.

    Returns:
        A new tensor shaped like ``imgs`` with every element in [0, 1].
    """
    gaussian = torch.randn_like(imgs)
    return torch.clamp(imgs + noise_factor * gaussian, min=0., max=1.)

# ====== Denoising Autoencoder Model ======
class DenoisingAutoencoder(nn.Module):
    """Fully connected autoencoder for 28x28 single-channel images.

    The encoder maps a flattened 784-value image to a 64-value code through a
    128-unit hidden layer; the decoder mirrors it and ends in a sigmoid so the
    reconstruction lies in [0, 1].
    """

    def __init__(self):
        """Build the encoder and decoder stacks."""
        # The constructor must be spelled with double underscores: a method
        # named `_init_` is never invoked by Python, which would leave the
        # module without layers and without any trainable parameters.
        super().__init__()
        self.encoder = nn.Sequential(
            nn.Linear(28 * 28, 128),
            nn.ReLU(),
            nn.Linear(128, 64),
        )
        self.decoder = nn.Sequential(
            nn.Linear(64, 128),
            nn.ReLU(),
            nn.Linear(128, 28 * 28),
            nn.Sigmoid(),  # output between 0 and 1
        )

    def forward(self, x):
        """Reconstruct a batch of images.

        Args:
            x: Tensor whose elements can be regrouped into rows of 784 values
                (e.g. a batch shaped ``(N, 1, 28, 28)``).

        Returns:
            Tensor of shape ``(N, 1, 28, 28)`` with values in [0, 1].
        """
        # reshape (rather than view) also accepts non-contiguous inputs.
        x = x.reshape(-1, 28 * 28)  # Flatten
        x = self.encoder(x)
        x = self.decoder(x)
        return x.view(-1, 1, 28, 28)  # Reshape back to image

# ====== Initialize Model, Loss, Optimizer ======
model = DenoisingAutoencoder()

# Mean squared error between the reconstruction and its target.
criterion = nn.MSELoss()

# Adam over every parameter the model exposes, with a learning rate of 1e-3.
optimizer = optim.Adam(params=model.parameters(), lr=1e-3)

# ====== Training ======
print("Training Denoising Autoencoder...")
for epoch in range(5):
    # One full pass over train_loader per epoch; labels are ignored.
    for imgs, _ in train_loader:
        # Drop gradients from the previous step before computing new ones.
        optimizer.zero_grad()

        # Corrupt the batch, then score the reconstruction against the
        # clean images so the model learns to remove the noise.
        noisy_imgs = add_noise(imgs)
        outputs = model(noisy_imgs)
        loss = criterion(outputs, imgs)

        loss.backward()
        optimizer.step()

    # The reported value is the loss of the epoch's last mini-batch only.
    print(f"Epoch {epoch+1}, Loss: {loss.item():.4f}")

# ====== Visualize Results ======
# Take one batch from train_loader and corrupt it for the demonstration.
test_imgs, _ = next(iter(train_loader))
noisy_test = add_noise(test_imgs)

# Gradients are not needed for plotting, so run the model under no_grad.
with torch.no_grad():
    reconstructed = model(noisy_test)

# Display Original, Noisy, and Reconstructed Images
# Each entry is (panel title, batch to draw from), in left-to-right order.
panels = (
    ("Original", test_imgs),
    ("Noisy", noisy_test),
    ("Denoised", reconstructed),
)
for i in range(5):
    # One figure per sample, with the three views side by side.
    plt.figure(figsize=(9,3))
    for col, (title, batch) in enumerate(panels, start=1):
        plt.subplot(1, 3, col)
        plt.imshow(batch[i].squeeze().numpy(), cmap='gray')
        plt.title(title)
        plt.axis('off')

    plt.show()