In [3]:
# --- CELL 1: Parallel Generative Training (Fixed for Mac M1) ---
import os
import sys

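# 1. REMOVE THE MPS MEMORY CAP (M1 Optimization)
# A ratio of 0.0 disables PyTorch's high-watermark limit on MPS allocations.
# It is set before `import torch` so the MPS allocator picks it up; note that
# without the limit a genuine out-of-memory condition can destabilise the system.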
os.environ['PYTORCH_MPS_HIGH_WATERMARK_RATIO'] = '0.0'

import torch
import torch.nn as nn
import torch.nn.functional as F 
import torchvision.models as models
from torch.utils.data import DataLoader
from torchvision import transforms
from diffusers import UNet2DModel, DDPMScheduler
from tqdm.auto import tqdm
import gc

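# 2. IMPORT THE DATASET (Critical for Multiprocessing on Mac)
# The dataset class must live in an importable module. DataLoader workers on
# macOS are started as fresh processes and cannot unpickle a class defined
# inside the notebook itself, which surfaces as an AttributeError.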
sys.path.append(os.path.abspath('../src'))
from dataset import FullScaleImageDataset

# --- CONFIGURATION ---
# Root folder handed to the dataset, number of passes over the data, and the
# Adam learning rate.
IMAGE_ROOT = '../data/sandbox'
GEN_EPOCHS = 1
LR = 1e-4

# PARALLEL SETTINGS
BATCH_SIZE = 16          # Physical batch size (samples per forward/backward pass)
ACCUMULATION_STEPS = 2   # Micro-batches per optimizer step -> effective batch size = 32
NUM_WORKERS = 4          # DataLoader worker processes

# MPS Setup: prefer the Apple-silicon GPU backend, otherwise fall back to CPU.
if torch.backends.mps.is_available():
    DEVICE = torch.device("mps")
    # The emoji are written as unicode escapes so the banner stays intact even
    # if this source file is re-saved or decoded with the wrong encoding.
    print("\U0001F680 Training on: Apple M1 GPU (MPS)")
    print(f"\u26A1 Parallel Workers: {NUM_WORKERS}")
    # Start from a clean allocator cache before the models are built.
    torch.mps.empty_cache()
else:
    DEVICE = torch.device("cpu")
    print("Training on: CPU")

# --- MODELS ---
class SemanticEncoder(nn.Module):
    """ResNet-18 backbone that maps a single-channel image to a latent vector.

    The pretrained network's first convolution is swapped for a 1-input-channel
    convolution whose kernel is the average of the original RGB kernels, the
    final classification layer is dropped, and a linear layer projects the
    512 pooled features to ``latent_dim`` values.
    """

    def __init__(self, latent_dim=256):
        super().__init__()
        backbone = models.resnet18(weights=models.ResNet18_Weights.DEFAULT)

        # Build a grayscale stem and seed it with the channel-averaged
        # pretrained RGB kernel so the pretrained filters are reused.
        rgb_stem = backbone.conv1
        gray_stem = nn.Conv2d(1, 64, kernel_size=7, stride=2, padding=3, bias=False)
        with torch.no_grad():
            gray_stem.weight.copy_(rgb_stem.weight.sum(dim=1, keepdim=True) / 3.0)
        backbone.conv1 = gray_stem

        # Keep every child module except the last one (the classifier head).
        trunk = list(backbone.children())[:-1]
        self.features = nn.Sequential(*trunk)
        self.projection = nn.Linear(512, latent_dim)

    def forward(self, x):
        """Return the projected latent for a batch of images ``x``."""
        batch_size = x.size(0)
        pooled = self.features(x)
        flat = pooled.view(batch_size, -1)
        return self.projection(flat)

print("Initializing Models...")

# Noise-prediction UNet for 64x64 single-channel images. With
# class_embed_type="identity" the tensor passed as `class_labels` is used
# directly as the conditioning embedding.
# NOTE(review): presumably the encoder's latent_dim must match the UNet's
# time-embedding width for this to work; confirm against the diffusers docs.
unet = UNet2DModel(
    sample_size=64,
    in_channels=1,
    out_channels=1,
    layers_per_block=2,
    block_out_channels=(64, 128, 128, 256),
    down_block_types=(
        "DownBlock2D",
        "DownBlock2D",
        "AttnDownBlock2D",
        "DownBlock2D",
    ),
    up_block_types=(
        "UpBlock2D",
        "AttnUpBlock2D",
        "UpBlock2D",
        "UpBlock2D",
    ),
    class_embed_type="identity",
)
unet = unet.to(DEVICE)

# Conditioning encoder; its output is fed to the UNet as `class_labels`.
encoder = SemanticEncoder(latent_dim=256)
encoder = encoder.to(DEVICE)

# Noise schedule with 1000 training timesteps.
scheduler = DDPMScheduler(num_train_timesteps=1000)

# A single Adam optimizer updates the UNet and the encoder jointly.
optimizer = torch.optim.Adam(
    [*unet.parameters(), *encoder.parameters()],
    lr=LR,
)

# --- DATA PIPELINE ---
# Resize every image to 64x64, convert it to a tensor, and normalise with
# mean 0.5 / std 0.5 (maps ToTensor's [0, 1] range to [-1, 1]).
# NOTE(review): the single-value mean/std assumes the dataset yields
# single-channel images; confirm against FullScaleImageDataset.
transform = transforms.Compose([
    transforms.Resize((64, 64)),
    transforms.ToTensor(),
    transforms.Normalize([0.5], [0.5])
])

# Initialize Dataset from the imported class
dataset = FullScaleImageDataset(IMAGE_ROOT, transform=transform)

loader = DataLoader(
    dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
    num_workers=NUM_WORKERS,  # Parallel loading; requires an importable dataset class
    # Pinned host memory only benefits CUDA transfers. On the MPS/CPU devices
    # this script selects it is unused and merely triggers a DataLoader warning.
    pin_memory=(DEVICE.type == "cuda"),
)

print(f"Starting Training ({GEN_EPOCHS} Epochs)...")

# Create the checkpoint directory up front. The final save below always runs,
# so the directory must exist even when no periodic (every-10-epochs)
# checkpoint is written, and a path problem should surface before training.
os.makedirs('../models', exist_ok=True)

num_batches = len(loader)
on_mps = DEVICE.type == "mps"

for epoch in range(GEN_EPOCHS):
    unet.train()
    encoder.train()
    total_loss = 0
    optimizer.zero_grad()

    progress = tqdm(loader, desc=f"Epoch {epoch+1}/{GEN_EPOCHS}", leave=False)

    for i, images in enumerate(progress):
        images = images.to(DEVICE)

        # 1. Encode the clean images into the conditioning latent
        z = encoder(images)

        # 2. Add noise at a random timestep per sample
        noise = torch.randn_like(images)
        t = torch.randint(0, scheduler.config.num_train_timesteps, (images.shape[0],), device=DEVICE).long()
        noisy_images = scheduler.add_noise(images, noise, t)

        # 3. Predict the noise, conditioned on the latent
        noise_pred = unet(noisy_images, t, class_labels=z).sample

        # 4. Loss & accumulate (scaled so accumulated gradients average
        #    over the micro-batches of one optimizer step)
        loss = F.mse_loss(noise_pred, noise)
        loss = loss / ACCUMULATION_STEPS
        loss.backward()

        # Step after every full accumulation group AND on the final batch, so
        # gradients from a trailing partial group are applied instead of being
        # discarded by the zero_grad() at the start of the next epoch.
        is_last_batch = (i + 1) == num_batches
        if (i + 1) % ACCUMULATION_STEPS == 0 or is_last_batch:
            optimizer.step()
            optimizer.zero_grad()
            if on_mps:
                torch.mps.empty_cache()

        # Undo the accumulation scaling so the reported value is the
        # per-batch MSE.
        current_loss = loss.item() * ACCUMULATION_STEPS
        total_loss += current_loss
        progress.set_postfix({"loss": current_loss})

    # Guard against an empty loader to avoid a ZeroDivisionError.
    avg_loss = total_loss / max(num_batches, 1)
    print(f"Epoch {epoch+1} | Avg Loss: {avg_loss:.4f}")

    # Periodic checkpoint every 10 epochs.
    if (epoch + 1) % 10 == 0:
        torch.save(unet.state_dict(), f"../models/diffusion_unet_full_{epoch+1}.pth")
        torch.save(encoder.state_dict(), f"../models/semantic_encoder_full_{epoch+1}.pth")
        gc.collect()

# Final weights, written regardless of how many epochs ran.
torch.save(unet.state_dict(), "../models/diffusion_unet.pth")
torch.save(encoder.state_dict(), "../models/semantic_encoder.pth")
print("Training Complete.")

🚀 Training on: Apple M1 GPU (MPS)
⚡ Parallel Workers: 4
Initializing Models...
🚀 Found 9786 images for Full-Scale Training.
Starting Training (1 Epochs)...


Epoch 1/1:   0%|          | 0/612 [00:19<?, ?it/s]

Epoch 1 | Avg Loss: 0.0436
Training Complete.
