In [None]:
!pip install torch
!pip install scikit-image
!pip install tqdm
import time
from torch.utils.data import Dataset, DataLoader, Subset
from skimage.metrics import structural_similarity as ssim
from tqdm import tqdm
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader, Subset
import numpy as np
import matplotlib.pyplot as plt

In [None]:
# ---------------------------
# Dataset Loader for dSprites (with subset option)
# ---------------------------
class DSpritesDataset(Dataset):
    def __init__(self, npz_file, subset_size=None):
        data = np.load(npz_file, allow_pickle=True, encoding='latin1', mmap_mode='r')
        self.imgs = data['imgs']
        self.latents_values = data['latents_values'][:, 1:].astype(np.float32)  # Drop 'color' and use float32

        if subset_size:
            indices = np.random.choice(len(self.imgs), subset_size, replace=False)
            self.imgs = self.imgs[indices]
            self.latents_values = self.latents_values[indices]

    def __len__(self):
        return len(self.imgs)

    def __getitem__(self, idx):
        image = self.imgs[idx].astype(np.float32)[np.newaxis, ...]  # Add channel dim
        params = self.latents_values[idx]
        return torch.tensor(image), torch.tensor(params)

In [None]:
# ---------------------------
# Residual Block for U-Net
# ---------------------------
class ResidualDoubleConv(nn.Module):
    def __init__(self, in_channels, out_channels):
        super(ResidualDoubleConv, self).__init__()
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1)
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1)
        self.bn2 = nn.BatchNorm2d(out_channels)
        self.relu = nn.ReLU(inplace=True)
        self.residual = nn.Conv2d(in_channels, out_channels, kernel_size=1, padding=0)
    
    def forward(self, x):
        residual = self.residual(x)
        x = self.relu(self.bn1(self.conv1(x)))
        x = self.relu(self.bn2(self.conv2(x)) + residual)
        return x

In [None]:
# ---------------------------
# Conditional U-Net Model
# ---------------------------
class ConditionalUNet(nn.Module):
    def __init__(self, n_channels=6, n_classes=1, bilinear=True):
        super(ConditionalUNet, self).__init__()
        self.inc = ResidualDoubleConv(n_channels, 64)
        self.down1 = ResidualDoubleConv(64, 128)
        self.down2 = ResidualDoubleConv(128, 256)
        self.down3 = ResidualDoubleConv(256, 512)
        self.down4 = ResidualDoubleConv(512, 1024)
        self.up1 = ResidualDoubleConv(1024, 512)
        self.up2 = ResidualDoubleConv(512, 256)
        self.up3 = ResidualDoubleConv(256, 128)
        self.up4 = ResidualDoubleConv(128, 64)
        self.outc = nn.Conv2d(64, n_classes, kernel_size=1)

    def forward(self, dummy_image, params):
        params = params.unsqueeze(-1).unsqueeze(-1).expand(-1, -1, 64, 64)
        x = torch.cat([dummy_image, params], dim=1)
        x1 = self.inc(x)
        x2 = self.down1(x1)
        x3 = self.down2(x2)
        x4 = self.down3(x3)
        x5 = self.down4(x4)
        x = self.up1(x5)
        x = self.up2(x)
        x = self.up3(x)
        x = self.up4(x)
        logits = self.outc(x)
        return torch.sigmoid(logits)

In [None]:
# ---------------------------
# GPU Setup and Checkpoints
# ---------------------------

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = ConditionalUNet().to(device)
optimizer = optim.Adam(model.parameters(), lr=0.001)
criterion = nn.MSELoss()

# Function to save model state
def save_checkpoint(model, optimizer, epoch, path="checkpoint.pth"):
    torch.save({
        'epoch': epoch,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict()
    }, path)
# Load checkpoint if exists
try:
    checkpoint = torch.load("checkpoint.pth", map_location=device)
    model.load_state_dict(checkpoint['model_state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    start_epoch = checkpoint['epoch']
except:
    start_epoch = 0

In [43]:
# ---------------------------
# Training Loop with Progress Bar
# ---------------------------
dataset = DSpritesDataset("C:\Desktop\Mubadala Project\dsprites_ndarray_co1sh3sc6or40x32y32_64x64.npz", subset_size=1000)  # Smaller subset
dataloader = DataLoader(dataset, batch_size=8, shuffle=True)  # Smaller batch size
num_epochs = 1  # Reduce epochs

for epoch in range(start_epoch, num_epochs):
    start_time = time.time()
    model.train()
    epoch_loss = 0
    progress_bar = tqdm(enumerate(dataloader), total=len(dataloader), desc=f"Epoch {epoch+1}/{num_epochs}")
    
    for batch_idx, (images, params) in progress_bar:
        images, params = images.to(device), params.to(device)
        optimizer.zero_grad()
        outputs = model(images, params)
        loss = criterion(outputs, images)
        loss.backward()
        optimizer.step()
        epoch_loss += loss.item()
        progress_bar.set_postfix(loss=loss.item())
    
    end_time = time.time()
    avg_loss = epoch_loss / len(dataloader)
    print(f"Epoch {epoch+1} completed in {end_time - start_time:.2f} seconds - Avg Loss: {avg_loss:.4f}")
    if (epoch + 1) % 5 == 0:
        save_checkpoint(model, optimizer, epoch+1)



Epoch 1/1:   2%|▏         | 3/125 [01:02<42:18, 20.81s/it, loss=0.0411]


KeyboardInterrupt: 

In [None]:
# ---------------------------
# Testing / Inference Example
# ---------------------------
model.eval()
with torch.no_grad():
    sample_params = torch.tensor([[0, 0.5, 0.1, 0.5, 0.5]]).to(device)
    dummy_image = torch.zeros((1, 1, 64, 64)).to(device)  # Use zeros
    generated = model(dummy_image, sample_params)
    plt.imshow(generated.squeeze().cpu().numpy(), cmap='gray')
    plt.title("Generated dSprites Image")
    plt.axis('off')
    plt.show()




In [None]:
import torch
import numpy as np
import matplotlib.pyplot as plt

# Load dataset
data = np.load("C:/Desktop/Mubadala Project/dsprites_ndarray_co1sh3sc6or40x32y32_64x64.npz", allow_pickle=True)

# Select a random sample
idx = np.random.randint(len(data['imgs']))
real_image = data['imgs'][idx].astype(np.float32)  # Actual image
latent_params = torch.tensor(data['latents_values'][idx, 1:], dtype=torch.float32).unsqueeze(0).to(device)

# Generate image using Conditional U-Net
model.eval()
with torch.no_grad():
    dummy_input = torch.zeros((1, 1, 64, 64)).to(device)  # Zero-input image
    generated_image = model(dummy_input, latent_params).squeeze().cpu().numpy()

# Plot real vs. generated image
fig, axes = plt.subplots(1, 2, figsize=(8, 4))
axes[0].imshow(real_image, cmap="gray")
axes[0].set_title("Actual Image")
axes[0].axis("off")

axes[1].imshow(generated_image, cmap="gray")
axes[1].set_title("Generated Image")
axes[1].axis("off")

plt.show()

