In [None]:
# ==============================
# 🧩 STEP 1 — Imports
# ==============================
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms, datasets, utils
import matplotlib.pyplot as plt
from torchsummary import summary  # For model visualization
from torchviz import make_dot       # Optional: graphical architecture view

# ==============================
# 🧩 STEP 2 — Dataset Loader
# ==============================
# data_dir may be either:
#   * a folder of class subfolders (torchvision ImageFolder layout), or
#   * a flat folder of image files (.jpg/.png/...), with no subfolders.
# Labels are not used for training; only the images matter.
data_dir = "/path/to/your/image/folder"  # Change this!

transform = transforms.Compose([
    transforms.Resize((256, 256)),
    transforms.ToTensor()
])


class FlatImageFolder(Dataset):
    """Dataset over the image files that sit directly inside *root*.

    Yields ``(image, 0)`` pairs so it can replace ``ImageFolder`` in loops
    that unpack ``imgs, _``.

    Raises:
        FileNotFoundError: if *root* contains no recognised image files.
    """

    def __init__(self, root, transform=None):
        self.root = root
        self.transform = transform
        # IMG_EXTENSIONS is torchvision's tuple of accepted lowercase suffixes.
        self.paths = sorted(
            os.path.join(root, name)
            for name in os.listdir(root)
            if name.lower().endswith(datasets.folder.IMG_EXTENSIONS)
        )
        if not self.paths:
            raise FileNotFoundError(f"No image files found in {root!r}")

    def __len__(self):
        return len(self.paths)

    def __getitem__(self, index):
        # Same loader ImageFolder uses, so images are decoded identically.
        img = datasets.folder.default_loader(self.paths[index])
        if self.transform is not None:
            img = self.transform(img)
        return img, 0


def _has_subdirs(root):
    """Return True if *root* contains at least one subdirectory."""
    with os.scandir(root) as entries:
        return any(entry.is_dir() for entry in entries)


if _has_subdirs(data_dir):
    dataset = datasets.ImageFolder(data_dir, transform=transform)
else:
    dataset = FlatImageFolder(data_dir, transform=transform)
dataloader = DataLoader(dataset, batch_size=8, shuffle=True)

# ==============================
# 🧩 STEP 3 — U-Net Definition
# ==============================
class DoubleConv(nn.Module):
    """Two stacked 3x3 convolutions, each followed by an in-place ReLU.

    ``padding=1`` keeps height and width unchanged; only the channel count
    goes from ``in_channels`` to ``out_channels``. The layers are stored in
    ``self.block`` (an ``nn.Sequential``): conv, ReLU, conv, ReLU.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        layers = []
        # First conv maps in->out, second maps out->out.
        for c_in in (in_channels, out_channels):
            layers.append(nn.Conv2d(c_in, out_channels, 3, padding=1))
            layers.append(nn.ReLU(inplace=True))
        self.block = nn.Sequential(*layers)

    def forward(self, x):
        out = self.block(x)
        return out


class UNet(nn.Module):
    """Two-level U-Net that maps an image to an image of the same spatial size.

    Encoder: two DoubleConv stages, each followed by 2x2 max-pooling.
    Bottleneck: one DoubleConv at 1/4 resolution.
    Decoder: two stride-2 transposed convolutions. Each result is concatenated
    with the matching encoder feature map (skip connection) and refined by a
    DoubleConv. The final 1x1 conv has no activation, so outputs are unbounded.

    Args:
        in_channels: channels of the input image (default 3).
        out_channels: channels of the output image (default 3).

    Height/width need not be multiples of 4. Upsampled features are zero-padded
    to their skip connection's size before concatenation.
    """

    def __init__(self, in_channels=3, out_channels=3):
        super().__init__()
        # Encoder
        self.down1 = DoubleConv(in_channels, 64)
        self.pool1 = nn.MaxPool2d(2)
        self.down2 = DoubleConv(64, 128)
        self.pool2 = nn.MaxPool2d(2)

        # Bottleneck
        self.bottleneck = DoubleConv(128, 256)

        # Decoder
        self.up1 = nn.ConvTranspose2d(256, 128, 2, stride=2)
        self.conv1 = DoubleConv(256, 128)  # 128 upsampled + 128 skip from down2
        self.up2 = nn.ConvTranspose2d(128, 64, 2, stride=2)
        self.conv2 = DoubleConv(128, 64)   # 64 upsampled + 64 skip from down1

        # Output
        self.output = nn.Conv2d(64, out_channels, 1)

    @staticmethod
    def _pad_to(x, ref):
        """Zero-pad 4-D tensor *x* so its last two dims match *ref*'s.

        Max-pooling floors odd sizes, while ConvTranspose2d(kernel 2, stride 2)
        exactly doubles. The upsampled map can therefore be one pixel smaller
        than its skip connection. The padding is split evenly between sides.
        """
        diff_h = ref.size(2) - x.size(2)
        diff_w = ref.size(3) - x.size(3)
        if diff_h == 0 and diff_w == 0:
            return x
        return F.pad(x, [diff_w // 2, diff_w - diff_w // 2,
                         diff_h // 2, diff_h - diff_h // 2])

    def forward(self, x):
        # Encoder
        d1 = self.down1(x)
        d2 = self.down2(self.pool1(d1))
        # Bottleneck
        b = self.bottleneck(self.pool2(d2))
        # Decoder with skip connections (size-matched before concatenation)
        u1 = self.up1(b)
        u1 = self.conv1(torch.cat([self._pad_to(u1, d2), d2], dim=1))
        u2 = self.up2(u1)
        u2 = self.conv2(torch.cat([self._pad_to(u2, d1), d1], dim=1))
        # Output
        return self.output(u2)


# ==============================
# 🧩 STEP 4 — Model Visualization
# ==============================
device = "cuda" if torch.cuda.is_available() else "cpu"
model = UNet().to(device)

# Print a layer-by-layer summary. Pass the device explicitly so the dummy
# input torchsummary builds lives where the model does.
summary(model, (3, 256, 256), device=device)

# Optional: render the autograd graph to unet_graph.png with torchviz.
# Rendering needs the Graphviz `dot` executable; skip instead of crashing
# when it is not installed.
x = torch.randn(1, 3, 256, 256).to(device)
y = model(x)
try:
    make_dot(y, params=dict(model.named_parameters())).render("unet_graph", format="png")
except RuntimeError as err:  # graphviz's ExecutableNotFound subclasses RuntimeError
    print(f"Skipping graph rendering: {err}")

# ==============================
# 🧩 STEP 5 — Training (Denoising Autoencoder)
# ==============================
# Reconstruction loss: mean-squared error between prediction and clean image.
criterion = nn.MSELoss()
# Adam over every model parameter with a fixed learning rate of 1e-4.
optimizer = torch.optim.Adam(params=model.parameters(), lr=1e-4)

def add_noise(img, noise_factor=0.2):
    """Return a noisy copy of *img*; the input tensor is left untouched.

    Standard-normal noise scaled by ``noise_factor`` is added element-wise,
    and the result is clamped to the range [0, 1].
    """
    noisy = img + torch.randn_like(img) * noise_factor
    return noisy.clamp(min=0.0, max=1.0)

# Number of passes over the dataset; increase for a real training run.
num_epochs = 1

model.train()  # explicit training mode, so this cell is safe to re-run after inference
for epoch in range(num_epochs):
    running_loss = 0.0
    num_samples = 0
    for imgs, _ in dataloader:  # labels unused: the target is the clean image itself
        imgs = imgs.to(device)
        noisy_imgs = add_noise(imgs)
        preds = model(noisy_imgs)
        loss = criterion(preds, imgs)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        # MSELoss returns a per-batch mean; weight it by batch size so the
        # epoch average stays correct when the last batch is smaller.
        running_loss += loss.item() * imgs.size(0)
        num_samples += imgs.size(0)
    if num_samples == 0:
        raise RuntimeError("dataloader yielded no batches; check data_dir")
    print(f"Epoch [{epoch+1}] Loss: {running_loss / num_samples:.4f}")

# ==============================
# 🧩 STEP 6 — Visualize Results
# ==============================
# Denoise one batch: take a clean batch, corrupt it, and run the model on it.
model.eval()  # inference mode; matters as soon as dropout/batch-norm layers exist
imgs, _ = next(iter(dataloader))
noisy_imgs = add_noise(imgs)
with torch.no_grad():  # no autograd graph needed for visualisation
    outputs = model(noisy_imgs.to(device)).cpu()

def show_images(img_list, titles):
    """Display the first image of each batch side by side in one figure.

    Args:
        img_list: sequence of image batches. Element 0 of each batch is shown,
            indexed as (C, H, W) and permuted to (H, W, C) for matplotlib.
            Assumes C is 3 (RGB) -- TODO confirm for other callers.
        titles: one title per batch, paired with img_list via zip (the
            shorter of the two sequences decides how many panels are filled).
    """
    # squeeze=False always returns a 2-D array of Axes, so indexing also
    # works when only one image is passed.
    _, axes = plt.subplots(1, len(img_list), figsize=(15, 5), squeeze=False)
    for ax, img, title in zip(axes[0], img_list, titles):
        # Model outputs are unbounded; clamp to [0, 1] so imshow shows them
        # without clipping warnings. detach/cpu make any tensor safe to plot.
        ax.imshow(img[0].detach().cpu().clamp(0, 1).permute(1, 2, 0))
        ax.set_title(title)
        ax.axis('off')
    plt.show()

# Compare one clean, noisy and denoised example from the same batch.
show_images(
    img_list=[imgs, noisy_imgs, outputs],
    titles=["Original", "Noisy", "Denoised"],
)

# ==============================
# 🧩 STEP 7 — Visualize Filters
# ==============================
# Show up to 8 kernels of the first convolution as tiny images. Each kernel
# has shape (in_channels, 3, 3) and is permuted to (H, W, C) for imshow.
filters = model.down1.block[0].weight.detach().clone().cpu()
num_shown = min(8, filters.size(0))  # the grid below holds at most 8 panels
plt.figure(figsize=(10, 5))
for i in range(num_shown):
    plt.subplot(2, 4, i + 1)
    f = filters[i]
    # Min-max normalise to [0, 1]; clamp the range so a constant kernel
    # does not produce 0/0 = NaN.
    span = (f.max() - f.min()).clamp_min(1e-8)
    f = (f - f.min()) / span
    plt.imshow(f.permute(1, 2, 0))
    plt.axis('off')
plt.suptitle("Example Filters from First Conv Layer")
plt.show()
