In [16]:
import torch
import torch.nn as nn
import torch.optim as optim
import matplotlib.pyplot as plt
from matplotlib.animation import FuncAnimation
import numpy as np

In [19]:
import torch
import torch.nn as nn
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.animation import FuncAnimation
from IPython.display import HTML

# Compute device shared by every tensor and module below: CUDA when PyTorch
# reports an available GPU, otherwise the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

class DiffusionNoiseSchedule:
    """Linear beta schedule with precomputed cumulative signal coefficients.

    Args:
        steps: Number of discrete diffusion timesteps (must be >= 1).
        beta_start: Noise variance added at the first timestep.
        beta_end: Noise variance added at the last timestep.

    The tensors are moved to the module-level ``device``.
    """

    def __init__(self, steps=1000, beta_start=1e-4, beta_end=0.02):
        if steps < 1:
            raise ValueError(f"steps must be >= 1, got {steps}")
        self.steps = steps
        self.betas = torch.linspace(beta_start, beta_end, steps).to(device)
        self.alphas = 1. - self.betas
        # alpha_bar_t = prod_{s <= t} alpha_s
        self.alpha_bars = torch.cumprod(self.alphas, dim=0)
        # Precompute the square roots once instead of on every lookup.
        self.sqrt_alpha_bars = torch.sqrt(self.alpha_bars)
        self.sqrt_one_minus_alpha_bars = torch.sqrt(1. - self.alpha_bars)

    def get_values(self, t):
        """Return the forward-process mixing coefficients at timestep(s) ``t``.

        Args:
            t: Integer index, or integer tensor of indices, in ``[0, steps)``.

        Returns:
            dict with keys ``'sqrt_alpha_bar'`` and ``'sqrt_one_minus_alpha_bar'``,
            each the corresponding table indexed by ``t``.
        """
        return {
            'sqrt_alpha_bar': self.sqrt_alpha_bars[t],
            'sqrt_one_minus_alpha_bar': self.sqrt_one_minus_alpha_bars[t],
        }

class DiffusionUNet(nn.Module):
    """Time-conditioned MLP noise predictor (built only from Linear layers).

    The scalar time is lifted to a ``hidden``-wide embedding, concatenated
    with the input along the last axis, and passed through a three-Linear
    MLP whose output width is ``dim``.
    """

    def __init__(self, dim=2, hidden=128):
        super().__init__()
        # Scalar t -> hidden-dimensional embedding.
        embed_layers = [nn.Linear(1, hidden), nn.SiLU(), nn.Linear(hidden, hidden)]
        self.time_embed = nn.Sequential(*embed_layers)
        # [x, t_emb] -> output of width dim.
        trunk_layers = [
            nn.Linear(dim + hidden, hidden),
            nn.SiLU(),
            nn.Linear(hidden, hidden),
            nn.SiLU(),
            nn.Linear(hidden, dim),
        ]
        self.main = nn.Sequential(*trunk_layers)

    def forward(self, x, t):
        """Predict from ``x`` and ``t``; assumes x is (batch, dim) and t has batch elements."""
        embedding = self.time_embed(t.view(-1, 1))
        joint = torch.cat([x, embedding], -1)
        return self.main(joint)
    
class VectorFieldMLP(nn.Module):
    """MLP mapping a point plus a scalar time to a ``dim``-wide output.

    Time is appended to the input as one extra feature (no separate
    embedding), so the first layer sees ``dim + 1`` features.
    """

    def __init__(self, dim, hidden=128):
        super().__init__()
        widths = [dim + 1, hidden, hidden]
        layers = []
        # Two hidden Linear+SiLU stages, then a linear read-out to `dim`.
        for fan_in, fan_out in zip(widths, widths[1:]):
            layers += [nn.Linear(fan_in, fan_out), nn.SiLU()]
        layers.append(nn.Linear(hidden, dim))
        self.net = nn.Sequential(*layers)

    def forward(self, x, t):
        """Evaluate on ``x`` with time ``t``; assumes x is (batch, dim) and t has batch elements."""
        features = torch.cat([x, t.view(-1, 1)], -1)
        return self.net(features)

class DiffusionModel:
    """Noise-predicting diffusion model for ``dim``-dimensional points.

    Bundles a ``VectorFieldMLP`` noise predictor, a ``DiffusionNoiseSchedule``
    with ``steps`` timesteps, and an Adam optimizer (lr=1e-3).
    """

    def __init__(self, dim, steps=1000):
        self.dim = dim  # data dimensionality; also the shape used when sampling
        self.net = VectorFieldMLP(dim).to(device)
        self.schedule = DiffusionNoiseSchedule(steps)
        self.optimizer = torch.optim.Adam(self.net.parameters(), lr=1e-3)
        self.steps = steps

    def train(self, data, epochs=2000):
        """Fit the network to predict the noise added by the forward process.

        Each epoch is one full-batch optimizer step: every row of ``data`` gets
        a uniformly random timestep and fresh Gaussian noise; the loss is the
        MSE between predicted and true noise. The loss is printed every 100
        epochs.

        Args:
            data: Training points; assumed (N, dim) on ``device`` — TODO confirm.
            epochs: Number of optimizer steps.
        """
        for epoch in range(epochs):
            t = torch.randint(0, self.steps, (len(data),), device=device)
            noise = torch.randn_like(data)

            # Forward process: x_t = sqrt(alpha_bar_t) * x_0 + sqrt(1 - alpha_bar_t) * eps
            coeffs = self.schedule.get_values(t)
            noisy_data = coeffs['sqrt_alpha_bar'][:, None] * data + \
                       coeffs['sqrt_one_minus_alpha_bar'][:, None] * noise

            # The network sees time normalized to [0, 1).
            pred_noise = self.net(noisy_data, t/self.steps)

            loss = torch.mean((pred_noise - noise) ** 2)
            self.optimizer.zero_grad()
            loss.backward()
            self.optimizer.step()

            if epoch % 100 == 0:
                print(f"Epoch {epoch}, Loss: {loss.item():.4f}")

    def _ddim_update(self, x, pred_noise, t, t_prev, eta):
        """One DDIM step from timestep ``t`` to ``t_prev`` (``t_prev < 0`` means clean data).

        Uses sigma = eta * sqrt((1 - a_prev) / (1 - a_t)) * sqrt(1 - a_t / a_prev),
        where a_* are cumulative alpha products.
        """
        alpha_bars = self.schedule.alpha_bars
        alpha_bar = alpha_bars[t]
        # The clean-data end of the chain has alpha_bar == 1.
        alpha_bar_prev = alpha_bars[t_prev] if t_prev >= 0 else alpha_bars.new_tensor(1.0)

        # Clean-sample estimate implied by the noise prediction.
        x0_pred = (x - (1 - alpha_bar).sqrt() * pred_noise) / alpha_bar.sqrt()

        # clamp() keeps sqrt() away from tiny negative arguments caused by rounding.
        sigma = eta * ((1 - alpha_bar_prev) / (1 - alpha_bar)).sqrt() \
            * (1 - alpha_bar / alpha_bar_prev).clamp(min=0).sqrt()
        dir_coeff = (1 - alpha_bar_prev - sigma ** 2).clamp(min=0).sqrt()

        # Re-noise the estimate to the noise level of t_prev.
        x_prev = alpha_bar_prev.sqrt() * x0_pred + dir_coeff * pred_noise
        if eta > 0:
            x_prev = x_prev + sigma * torch.randn_like(x)
        return x_prev

    def sample(self, num_samples=1000, steps=100, eta=0.0):
        """DDIM Sampling.

        Starts from standard Gaussian noise and visits the multiples of
        ``stride = max(1, self.steps // steps)`` below ``self.steps`` in
        decreasing order. The last update (from t=0) returns the clean-data
        estimate.

        Args:
            num_samples: Number of points to draw.
            steps: Requested number of denoising steps. When ``self.steps`` is
                not divisible by it, a few more steps are taken.
            eta: Stochasticity. 0.0 (default) is deterministic DDIM; 1.0 adds
                DDPM-like ancestral noise.

        Returns:
            CPU tensor of shape (num_samples, dim).

        Raises:
            ValueError: if ``steps < 1``.
        """
        if steps < 1:
            raise ValueError(f"steps must be >= 1, got {steps}")
        x = torch.randn(num_samples, self.dim, device=device)
        # max() avoids a zero stride when more steps than timesteps are requested.
        stride = max(1, self.steps // steps)

        with torch.no_grad():
            for t in reversed(range(0, self.steps, stride)):
                # Time embedding (same [0, 1) normalization as in training)
                t_tensor = torch.full((num_samples,), t/self.steps, device=device)

                # Predict noise
                pred_noise = self.net(x, t_tensor)

                # Reverse process step
                x = self._ddim_update(x, pred_noise, t, t - stride, eta)

        return x.cpu()

# 4. Compositional Diffusion
class DiffusionComposer:
    """Sample from several trained models by mixing their noise predictions.

    Only ``models[0].schedule`` is used. All models are therefore assumed to
    share the same schedule (same ``steps`` and betas).
    """

    def __init__(self, models):
        self.models = models
        self.schedule = models[0].schedule  # Shared schedule

    def _combine(self, noises, guidance):
        """Mix per-model noise predictions.

        Returns ``mean(noises) + guidance * sum_i (noises[i] - noises[0])``.

        NOTE(review): ``noises[0]`` acts as the reference prediction, so the rule
        is asymmetric in the models. Unlike standard classifier-free guidance,
        it starts from the mean rather than from the reference. The rule is kept
        as-is; confirm it is intended.
        """
        reference = noises[0]
        mean = sum(noises) / len(noises)
        return mean + guidance * sum(n - reference for n in noises[1:])

    def _ddim_update(self, x, pred_noise, t, t_prev, eta):
        """One DDIM step from timestep ``t`` to ``t_prev`` (``t_prev < 0`` means clean data)."""
        alpha_bars = self.schedule.alpha_bars
        alpha_bar = alpha_bars[t]
        # The clean-data end of the chain has alpha_bar == 1.
        alpha_bar_prev = alpha_bars[t_prev] if t_prev >= 0 else alpha_bars.new_tensor(1.0)

        # Clean-sample estimate implied by the noise prediction.
        x0_pred = (x - (1 - alpha_bar).sqrt() * pred_noise) / alpha_bar.sqrt()

        # clamp() keeps sqrt() away from tiny negative arguments caused by rounding.
        sigma = eta * ((1 - alpha_bar_prev) / (1 - alpha_bar)).sqrt() \
            * (1 - alpha_bar / alpha_bar_prev).clamp(min=0).sqrt()
        dir_coeff = (1 - alpha_bar_prev - sigma ** 2).clamp(min=0).sqrt()

        # Re-noise the estimate to the noise level of t_prev.
        x_prev = alpha_bar_prev.sqrt() * x0_pred + dir_coeff * pred_noise
        if eta > 0:
            x_prev = x_prev + sigma * torch.randn_like(x)
        return x_prev

    def sample(self, num_samples=1000, steps=100, guidance=3.0, eta=0.0):
        """Classifier-free guidance composition, sampled with DDIM updates.

        Args:
            num_samples: Number of points to draw.
            steps: Requested number of denoising steps (stride is at least 1).
            guidance: Weight of the guidance term (see ``_combine``).
            eta: 0.0 (default) is deterministic DDIM; > 0 adds per-step noise.

        Returns:
            CPU tensor of shape (num_samples, dim). ``dim`` is taken from
            ``models[0].dim`` when present, else 2.

        Raises:
            ValueError: if ``steps < 1``.
        """
        if steps < 1:
            raise ValueError(f"steps must be >= 1, got {steps}")
        total = self.schedule.steps
        # max() avoids a zero stride when more steps than timesteps are requested.
        stride = max(1, total // steps)
        dim = getattr(self.models[0], 'dim', 2)
        x = torch.randn(num_samples, dim, device=device)

        with torch.no_grad():
            for t in reversed(range(0, total, stride)):
                t_tensor = torch.full((num_samples,), t / total, device=device)

                # Get predictions from all models, then mix them.
                noises = [model.net(x, t_tensor) for model in self.models]
                combined_noise = self._combine(noises, guidance)

                # Reverse process step
                x = self._ddim_update(x, combined_noise, t, t - stride, eta)

        return x.cpu()

# 5. Visualization System
class DiffusionVisualizer:
    """Animate the denoising trajectories of two models and of their composition.

    Three panels are drawn: ``models[0]``, ``models[1]``, and the composition
    of all models. Trajectories use deterministic DDIM updates.
    """

    def __init__(self, models):
        self.models = models
        self.composer = DiffusionComposer(models)

    def animate_sampling(self, num_samples=300, steps=50):
        """Build the three-panel sampling animation.

        Args:
            num_samples: Points drawn per panel.
            steps: Requested number of denoising steps per trajectory.

        Returns:
            ``FuncAnimation`` showing every second iterate and always ending on
            the final (fully denoised) one.
        """
        fig, ax = plt.subplots(1, 3, figsize=(18, 6))
        panels = [
            ('model1', 'Model 1', 'blue'),
            ('model2', 'Model 2', 'orange'),
            ('composed', 'Composed', 'green'),
        ]
        plots = {}
        for a, (key, title, color) in zip(ax, panels):
            plots[key] = a.scatter([], [], s=10, alpha=0.5, c=color)
            a.set_title(title)
            a.set_xlim(-4, 4)
            a.set_ylim(-4, 4)
            a.grid(True)

        trajectories = {
            'model1': self._get_trajectory(self.models[0], num_samples, steps),
            'model2': self._get_trajectory(self.models[1], num_samples, steps),
            'composed': self._get_composed_trajectory(num_samples, steps),
        }
        # Derive frames from the real trajectory length, not from `steps`.
        # When the schedule length is not divisible by `steps`, more iterates exist.
        n_frames = min(len(traj) for traj in trajectories.values())
        frames = list(range(0, n_frames, 2))
        if frames[-1] != n_frames - 1:
            frames.append(n_frames - 1)

        def update(frame):
            for key, scatter in plots.items():
                scatter.set_offsets(trajectories[key][frame])
            return tuple(plots.values())

        anim = FuncAnimation(fig, update, frames=frames, interval=50, blit=True)
        # Close this specific figure; presumably so notebooks don't also render a static copy.
        plt.close(fig)
        return anim

    @staticmethod
    def _timesteps(total_steps, steps):
        """Return (descending timestep indices, stride) for a ``steps``-step sampler."""
        if steps < 1:
            raise ValueError(f"steps must be >= 1, got {steps}")
        stride = max(1, total_steps // steps)
        return list(reversed(range(0, total_steps, stride))), stride

    def _denoise_trajectory(self, predict_noise, schedule, num_samples, steps):
        """Run deterministic DDIM from Gaussian noise, recording every iterate.

        Args:
            predict_noise: Callable ``(x, t_tensor) -> predicted noise``.
            schedule: Noise schedule providing ``steps`` and ``alpha_bars``.

        Returns:
            numpy array of shape (len(timesteps) + 1, num_samples, 2). Index 0
            is the initial noise; the last entry is the final estimate.
        """
        timesteps, stride = self._timesteps(schedule.steps, steps)
        alpha_bars = schedule.alpha_bars
        one = alpha_bars.new_tensor(1.0)  # alpha_bar of clean data

        x = torch.randn(num_samples, 2, device=device)
        # Sized from the actual timestep count so nothing can overflow.
        traj = torch.zeros(len(timesteps) + 1, num_samples, 2)
        traj[0] = x.cpu()

        with torch.no_grad():
            for i, t in enumerate(timesteps, start=1):
                t_tensor = torch.full((num_samples,), t / schedule.steps, device=device)
                eps = predict_noise(x, t_tensor)

                alpha_bar = alpha_bars[t]
                alpha_bar_prev = alpha_bars[t - stride] if t >= stride else one
                # Estimate the clean sample, then re-noise it to the level of t_prev.
                x0_pred = (x - (1 - alpha_bar).sqrt() * eps) / alpha_bar.sqrt()
                x = alpha_bar_prev.sqrt() * x0_pred + (1 - alpha_bar_prev).sqrt() * eps

                traj[i] = x.cpu()
        return traj.numpy()

    def _get_trajectory(self, model, num_samples, steps):
        """Denoising trajectory of a single model."""
        return self._denoise_trajectory(model.net, model.schedule, num_samples, steps)

    def _get_composed_trajectory(self, num_samples, steps):
        """Denoising trajectory using the plain mean of all models' noise predictions.

        NOTE(review): ``DiffusionComposer.sample`` adds a guidance term on top of
        the mean. This panel therefore does not show what ``self.composer.sample``
        would produce.
        """
        def mean_noise(x, t_tensor):
            noises = [model.net(x, t_tensor) for model in self.models]
            return sum(noises) / len(noises)

        return self._denoise_trajectory(
            mean_noise, self.models[0].schedule, num_samples, steps
        )

def create_data(centers, num_samples=1000):
    """Sample a mixture of unit-variance isotropic Gaussians, one per center.

    Args:
        centers: Sequence of coordinate tuples, all of the same length. That
            length is the data dimension.
        num_samples: Total number of points. They are split as evenly as
            possible; the first ``num_samples % len(centers)`` centers get one
            extra point.

    Returns:
        float32 tensor of shape (num_samples, len(centers[0])) on ``device``.
        Rows are grouped by center in the given order (not shuffled).

    Raises:
        ValueError: if ``centers`` is empty.
    """
    if len(centers) == 0:
        raise ValueError("create_data needs at least one center")
    per_center, extra = divmod(num_samples, len(centers))
    clusters = []
    for i, c in enumerate(centers):
        # float32 keeps the data compatible with the float32 networks.
        mean = torch.as_tensor(c, dtype=torch.float32)
        n = per_center + (1 if i < extra else 0)
        clusters.append(torch.randn(n, mean.numel()) + mean)
    return torch.cat(clusters).to(device)

# Two toy 2-D datasets, each a pair of Gaussian clusters on opposite corners.
data1 = create_data([(1.5,1.5), (-1.5,-1.5)])  # Diagonal clusters
data2 = create_data([(1.5,-1.5), (-1.5,1.5)])   # Cross clusters

# Construct both models before training either; left-to-right evaluation
# builds model1 first.
model1, model2 = DiffusionModel(2), DiffusionModel(2)

# Train each model on its own dataset.
for idx, (mdl, dat) in enumerate([(model1, data1), (model2, data2)], start=1):
    lead = "" if idx == 1 else "\n"
    print(f"{lead}Training Model {idx}...")
    mdl.train(dat)

# Animate the sampling trajectories of both models and their composition.
visualizer = DiffusionVisualizer([model1, model2])
animation = visualizer.animate_sampling()

# Write a GIF to the working directory, then render the animation inline
# (the trailing expression is what the notebook displays).
animation.save('diffusion_composition.gif', writer='pillow')
HTML(animation.to_jshtml())

Training Model 1...
Epoch 0, Loss: 1.0231
Epoch 100, Loss: 0.3498
Epoch 200, Loss: 0.3571
Epoch 300, Loss: 0.3791
Epoch 400, Loss: 0.3361
Epoch 500, Loss: 0.3703
Epoch 600, Loss: 0.3283
Epoch 700, Loss: 0.3662
Epoch 800, Loss: 0.3338
Epoch 900, Loss: 0.3622
Epoch 1000, Loss: 0.3603
Epoch 1100, Loss: 0.3477
Epoch 1200, Loss: 0.3287
Epoch 1300, Loss: 0.3430
Epoch 1400, Loss: 0.3137
Epoch 1500, Loss: 0.3234
Epoch 1600, Loss: 0.3321
Epoch 1700, Loss: 0.3713
Epoch 1800, Loss: 0.3267
Epoch 1900, Loss: 0.3400

Training Model 2...
Epoch 0, Loss: 0.9664
Epoch 100, Loss: 0.4228
Epoch 200, Loss: 0.3218
Epoch 300, Loss: 0.3627
Epoch 400, Loss: 0.3419
Epoch 500, Loss: 0.3586
Epoch 600, Loss: 0.3632
Epoch 700, Loss: 0.3625
Epoch 800, Loss: 0.3310
Epoch 900, Loss: 0.3869
Epoch 1000, Loss: 0.2981
Epoch 1100, Loss: 0.3340
Epoch 1200, Loss: 0.3512
Epoch 1300, Loss: 0.3254
Epoch 1400, Loss: 0.3472
Epoch 1500, Loss: 0.3186
Epoch 1600, Loss: 0.3653
Epoch 1700, Loss: 0.3323
Epoch 1800, Loss: 0.3423
Epoch 19