<a href="https://colab.research.google.com/github/amrzhd/DiffusionModel/blob/main/model.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Image Generation Training

## Install Dependencies

In [None]:
!pip install torch torchvision
!pip install torchsummary



## Import Libraries

In [None]:
import torch
import torch.nn as nn
import torchvision
from torch.utils.data import DataLoader
from torchvision import transforms, datasets

## Define Data Loader

In [None]:
# Preprocessing pipeline: convert each image to a tensor, then normalise every
# RGB channel with mean 0.5 and standard deviation 0.5.
transform = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
    ]
)

batch_size = 32

# CIFAR-10 training split, stored under ./data (downloaded on first use).
train_dataset = datasets.CIFAR10(
    root='./data',
    train=True,
    transform=transform,
    download=True,
)

# Shuffled mini-batches of `batch_size` samples.
train_loader = DataLoader(dataset=train_dataset, batch_size=batch_size, shuffle=True)


Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to ./data/cifar-10-python.tar.gz


100%|██████████| 170498071/170498071 [00:02<00:00, 60634617.81it/s]


Extracting ./data/cifar-10-python.tar.gz to ./data


## Define U-Net Architecture


In [None]:
class UNetBlock(nn.Module):
    """Convolution -> batch normalisation -> ReLU building block.

    Args:
        in_channels: Number of channels of the incoming feature map.
        out_channels: Number of channels produced by the convolution.
        kernel_size: Size of the convolution kernel (default 3).
        padding: Padding applied by the convolution (default 1).
    """

    def __init__(self, in_channels, out_channels, kernel_size=3, padding=1):
        super().__init__()
        self.conv = nn.Conv2d(
            in_channels,
            out_channels,
            kernel_size=kernel_size,
            padding=padding,
        )
        self.batchnorm = nn.BatchNorm2d(out_channels)
        # In-place activation: overwrites the batch-norm output instead of
        # allocating a new tensor.
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        """Run the three layers in sequence and return the activated feature map."""
        return self.relu(self.batchnorm(self.conv(x)))

class UNet(nn.Module):
    """U-Net-style convolutional network with concatenated skip connections.

    Every 3x3 convolution uses padding 1 and the output convolution is 1x1; the
    network contains no pooling or upsampling layers, so the output keeps the
    height and width of the input.

    Args:
        in_channels: Number of channels of the input image.
        out_channels: Number of channels of the produced output.
    """

    def __init__(self, in_channels, out_channels):
        super(UNet, self).__init__()
        self.kernel_size = 3
        self.padding = 1

        # Encoder: channel width doubles at every stage (64 -> 512).
        self.encoder = nn.Sequential(
            UNetBlock(in_channels, 64),
            UNetBlock(64, 128),
            UNetBlock(128, 256),
            UNetBlock(256, 512),
        )

        # Bottleneck: expand to 1024 channels, then reduce back to 512 so it can
        # be concatenated with the last encoder feature map.
        self.bottleneck = nn.Sequential(
            UNetBlock(512, 1024),
            nn.Conv2d(1024, 1024, kernel_size=self.kernel_size, padding=self.padding),
            nn.BatchNorm2d(1024),
            nn.ReLU(inplace=True),
            nn.Conv2d(1024, 512, kernel_size=self.kernel_size, padding=self.padding),
            nn.BatchNorm2d(512),
            nn.ReLU(inplace=True),
        )

        # Decoder: every stage receives the channel-wise concatenation of an
        # encoder skip connection and the previous stage, so its input width is
        # the SUM of both channel counts.
        self.decoder = nn.Sequential(
            UNetBlock(512 + 512, 512),  # enc4 (512) + bottleneck (512)
            UNetBlock(256 + 512, 256),  # enc3 (256) + dec1 (512)
            UNetBlock(128 + 256, 128),  # enc2 (128) + dec2 (256)
            UNetBlock(64 + 128, 64),    # enc1 (64)  + dec3 (128)
        )

        # Output layer: 1x1 convolution mapping 64 features to `out_channels`.
        self.out_conv = nn.Conv2d(64, out_channels, kernel_size=1)

    def forward(self, x):
        """Run the network on a batch of images.

        Args:
            x: Input tensor with `in_channels` channels in dimension 1.

        Returns:
            Tensor with `out_channels` channels in dimension 1.
        """
        # Encoder: keep every intermediate activation for the skip connections.
        enc1 = self.encoder[0](x)
        enc2 = self.encoder[1](enc1)
        enc3 = self.encoder[2](enc2)
        enc4 = self.encoder[3](enc3)

        # Bottleneck
        bn = self.bottleneck(enc4)

        # Decoder: concatenate along the channel dimension (dim=1).
        dec1 = self.decoder[0](torch.cat([enc4, bn], dim=1))
        dec2 = self.decoder[1](torch.cat([enc3, dec1], dim=1))
        dec3 = self.decoder[2](torch.cat([enc2, dec2], dim=1))
        dec4 = self.decoder[3](torch.cat([enc1, dec3], dim=1))

        # Output layer
        return self.out_conv(dec4)


## Instantiate the Model

In [None]:
from torchsummary import summary

in_channels = 3
out_channels = 3

# Define the device (GPU if available, otherwise CPU)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Build the model and move it to the selected device
model = UNet(in_channels, out_channels).to(device)

# Shape of one input sample (channels, height, width) without the batch
# dimension. This is a plain tuple, not a tensor, so it has no `.to(device)`.
input_size = (in_channels, 32, 32)

# Print the summary. torchsummary expects the device as the string "cuda" or
# "cpu", so pass `device.type` instead of the torch.device object.
summary(model, input_size=input_size, device=device.type)


NameError: ignored

## Define Loss Function and Optimizer

In [None]:
# Mean-squared-error loss.
criterion = nn.MSELoss()

# Adam optimiser over all model parameters with a learning rate of 1e-3.
optimizer = torch.optim.Adam(params=model.parameters(), lr=1e-3)

## Training Loop

In [None]:
num_epochs = 20

# Training mode: batch-norm layers use per-batch statistics. Set explicitly so
# re-running this cell after the evaluation cell (which calls model.eval())
# still trains correctly.
model.train()

for epoch in range(num_epochs):
    running_loss = 0.0

    for inputs, _ in train_loader:  # labels are not needed for reconstruction
        # The model lives on `device`; the batch must be moved there as well.
        inputs = inputs.to(device)

        # Forward pass: the network is trained to reproduce its input.
        outputs = model(inputs)
        loss = criterion(outputs, inputs)

        # Backward and optimize
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        running_loss += loss.item()

    # Report the mean batch loss of the epoch.
    epoch_loss = running_loss / len(train_loader)
    print(f'Epoch [{epoch+1}/{num_epochs}], Loss: {epoch_loss:.4f}')


Input shape: torch.Size([32, 3, 32, 32])


RuntimeError: ignored

## Save the Trained Model

In [None]:
# Persist only the model's state dict (its parameters and buffers), not the
# whole module object.
torch.save(obj=model.state_dict(), f='latent_diffusion_model.pth')

## Visualize Model Reconstructions

In [None]:
import matplotlib.pyplot as plt

# Load a sample batch from the training data
sample_batch, _ = next(iter(train_loader))

# Generate predictions using the trained model. The batch is moved to the
# model's device for the forward pass and the result is brought back to the
# CPU so it can be converted to NumPy for plotting.
model.eval()
with torch.no_grad():
    reconstructed_batch = model(sample_batch.to(device)).cpu()

# Display the original images (top row) and reconstructions (bottom row)
num_images = 8
fig, axes = plt.subplots(nrows=2, ncols=num_images, figsize=(16, 4))
for row, batch in enumerate((sample_batch, reconstructed_batch)):
    for col in range(num_images):
        # Undo Normalize(mean=0.5, std=0.5) and clamp to [0, 1], the range
        # imshow expects for float RGB data.
        image = (batch[col] * 0.5 + 0.5).clamp(0, 1)
        # Channels-first (C, H, W) -> channels-last (H, W, C) for imshow.
        axes[row, col].imshow(image.permute(1, 2, 0).numpy())
        axes[row, col].axis('off')

plt.show()