# Artificial Synesthesia Phase 2: Professional Generative Art

## 1. Setup & Config

In [None]:

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import torchaudio
import torchaudio.transforms as T
import numpy as np
import matplotlib.pyplot as plt
import librosa
import random
import os
import cv2
from PIL import Image
from moviepy.editor import VideoFileClip, AudioFileClip, ImageSequenceClip
import plotly.graph_objects as go
from scipy.spatial import Voronoi, voronoi_plot_2d
from google.colab import drive
import shutil

# Mount Drive for Checkpoints
# Mounting only works inside Google Colab; any failure falls back to a local
# directory so the rest of the notebook can still run.
try:
    drive.mount('/content/drive')
    CHECKPOINT_DIR = "/content/drive/MyDrive/Synesthesia_Checkpoints"
    os.makedirs(CHECKPOINT_DIR, exist_ok=True)
    print(f"Checkpoints will be saved to {CHECKPOINT_DIR}")
except Exception as exc:  # narrow enough to let KeyboardInterrupt/SystemExit through
    print("Google Drive not mounted. Checkpoints will be saved locally.")
    print(f"  (reason: {type(exc).__name__}: {exc})")
    CHECKPOINT_DIR = "./checkpoints"
    os.makedirs(CHECKPOINT_DIR, exist_ok=True)

# Configuration
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using device: {DEVICE}")

IMG_SIZE = 256
SAMPLE_RATE = 22050
DURATION = 1.0
N_FFT = 2048
HOP_LENGTH = 512
N_MELS = 256

# Hyperparameters Phase 2
# Only query GPU properties when a GPU exists: get_device_properties(0) raises
# on CPU-only machines. Use the larger batch when the GPU reports > 14e9 bytes.
if torch.cuda.is_available() and torch.cuda.get_device_properties(0).total_memory > 14e9:
    BATCH_SIZE = 32
else:
    BATCH_SIZE = 16
LR = 0.0001
LAMBDA_L1 = 150
EPOCHS = 200
START_EPOCH = 0


## 2. Unified Normalization

In [None]:

class SpectrogramNormalizer:
    """Turn a raw waveform into a fixed-size, normalized log-mel spectrogram.

    The same transform is used for training data and for inference so that the
    generator always sees inputs in the same value range and resolution.
    """

    # Cache of (MelSpectrogram, AmplitudeToDB) pairs keyed by device and the
    # spectrogram configuration, so the mel filterbank is not rebuilt per call.
    _pipelines = {}

    @staticmethod
    def _get_pipeline(device):
        """Return the cached (mel_transform, to_db) pair for ``device``."""
        key = (str(device), SAMPLE_RATE, N_FFT, HOP_LENGTH, N_MELS)
        pipeline = SpectrogramNormalizer._pipelines.get(key)
        if pipeline is None:
            # Waveform -> MelSpectrogram
            mel_transform = T.MelSpectrogram(
                sample_rate=SAMPLE_RATE,
                n_fft=N_FFT,
                win_length=N_FFT,
                hop_length=HOP_LENGTH,
                n_mels=N_MELS,
                power=2.0,
            ).to(device)
            pipeline = (mel_transform, torchaudio.transforms.AmplitudeToDB())
            SpectrogramNormalizer._pipelines[key] = pipeline
        return pipeline

    @staticmethod
    def transform(waveform):
        """Convert ``waveform`` into a normalized spectrogram tensor.

        Args:
            waveform: audio tensor, either 1-D (samples) or 2-D
                (channels, samples).

        Returns:
            Tensor of shape (channels, IMG_SIZE, IMG_SIZE) with values in
            [-1, 1]; a 1-D input yields a single channel.
        """
        mel_transform, to_db = SpectrogramNormalizer._get_pipeline(waveform.device)

        spec = to_db(mel_transform(waveform))

        # Strict Normalization: -80dB to 0dB -> [-1, 1]
        # (x + 40) / 40 maps -80 dB to -1 and 0 dB to +1; anything outside
        # that window is saturated by the clamp.
        spec = (spec + 40) / 40
        spec = torch.clamp(spec, -1, 1)

        # Resize to fixed input size (IMG_SIZE, IMG_SIZE)
        # Note: We interpolate to ensure a strict square input for the UNet
        if spec.dim() == 2:
            spec = spec.unsqueeze(0)

        # interpolate expects a batch dimension; add it and drop it again.
        spec = torch.nn.functional.interpolate(
            spec.unsqueeze(0),
            size=(IMG_SIZE, IMG_SIZE),
            mode='bilinear',
            align_corners=False,
        ).squeeze(0)
        return spec

normalizer = SpectrogramNormalizer()


## 3. Advanced Synthetic Data & Generator (Phase 2)

In [None]:

# The Phase 1 base class is optional: inherit from it when it has been defined
# in this session, otherwise fall back to ``object`` so the notebook stays
# self-contained instead of failing with a NameError.
try:
    _AudioGeneratorBase = SyntheticAudioGenerator
except NameError:
    _AudioGeneratorBase = object


class AdvancedAudioGenerator(_AudioGeneratorBase):
    """Generate short synthetic audio clips of several timbre families.

    Every ``generate_*`` method returns a tuple ``(audio, label, params)``
    where ``audio`` is a float32 NumPy array of ``n_samples`` samples,
    ``label`` names the signal type and ``params`` carries the frequencies
    used (or 0 for noise-based signals).
    """

    def __init__(self, sample_rate=22050, duration=1.0):
        """Store the sample rate (Hz) and clip duration (seconds).

        The base-class initializer is intentionally not called because every
        attribute used by the methods below is set here.
        """
        self.sr = sample_rate
        self.duration = duration
        self.n_samples = int(sample_rate * duration)

    def _time_axis(self):
        """Return the sample times ``k / sr`` for ``k = 0 .. n_samples - 1``.

        Using the sample index divided by the rate (rather than a linspace
        that includes the end point) keeps the spacing at exactly ``1 / sr``,
        so generated tones land on the requested frequency.
        """
        return np.arange(self.n_samples) / self.sr

    def generate_sine(self, freq=None):
        """Return a pure tone; ``freq`` is drawn from [200, 1000] Hz if omitted."""
        if freq is None:
            freq = random.uniform(200, 1000)
        t = self._time_axis()
        audio = np.sin(2 * np.pi * freq * t)
        return audio.astype(np.float32), "sine", freq

    def generate_white_noise(self):
        """Return Gaussian noise with standard deviation 0.5."""
        audio = np.random.normal(0, 0.5, self.n_samples)
        return audio.astype(np.float32), "noise", 0

    def generate_chirp(self):
        """Return a linear sweep from a random ``f0`` to a random ``f1``."""
        t = self._time_axis()
        f0 = random.uniform(100, 400)
        f1 = random.uniform(800, 1500)
        k = (f1 - f0) / self.duration  # sweep rate in Hz per second
        audio = np.sin(2 * np.pi * (f0 * t + 0.5 * k * t**2))
        return audio.astype(np.float32), "chirp", (f0, f1)

    def generate_fm(self):
        """Return a frequency-modulated tone with random carrier and modulator."""
        t = self._time_axis()
        carrier_freq = random.uniform(200, 800)
        mod_freq = random.uniform(10, 100)
        mod_index = random.uniform(1, 10)

        audio = np.sin(2 * np.pi * carrier_freq * t + mod_index * np.sin(2 * np.pi * mod_freq * t))
        return audio.astype(np.float32), "fm", (carrier_freq, mod_freq)

    def generate_percussive(self):
        """Return Gaussian noise shaped by a fast exponential decay envelope."""
        t = self._time_axis()
        noise = np.random.normal(0, 0.8, self.n_samples)
        decay = np.exp(-10 * t)  # fast decay
        audio = noise * decay
        return audio.astype(np.float32), "percussive", 0

class AdvancedVisualGenerator:
    """Produce square RGB target images, one visual family per audio type.

    Every method returns a float32 array of shape (size, size, 3) with values
    in [0, 1].
    """

    def __init__(self, img_size=256):
        self.size = img_size

    def generate_gradient(self, color_phase=0.0):
        """Vertical colour gradient; ``color_phase`` shifts the green channel."""
        ramp = np.linspace(0, 1, self.size)
        # Each row holds one value of the ramp, repeated across all columns.
        vertical = np.repeat(ramp[:, np.newaxis], self.size, axis=1)
        red = vertical
        green = np.sin(vertical * np.pi + color_phase) * 0.5 + 0.5
        blue = 1.0 - vertical
        return np.dstack((red, green, blue)).astype(np.float32)

    def generate_noise_texture(self):
        """Uniform RGB noise with its contrast stretched around 0.5."""
        raw = np.random.uniform(0, 1, (self.size, self.size, 3))
        stretched = (raw - 0.5) * 2.5 + 0.5  # High contrast
        return np.clip(stretched, 0, 1).astype(np.float32)

    def generate_structured_pattern(self):
        """Diagonal sinusoidal stripes; blue is the inverse of red/green."""
        axis = np.linspace(0, 10, self.size)
        # Broadcasting column + row coordinates gives the same grid as meshgrid.
        wave = np.sin(axis[np.newaxis, :] + axis[:, np.newaxis]) * 0.5 + 0.5
        return np.dstack((wave, wave, 1 - wave)).astype(np.float32)

    def generate_voronoi(self):
        """Sum of radial blobs around random seed points (Voronoi-like look).

        Used as the visual counterpart of complex (FM) timbres.
        """
        n_points = 20
        seeds = np.random.rand(n_points, 2) * self.size

        cols = np.arange(self.size)[np.newaxis, :]
        rows = np.arange(self.size)[:, np.newaxis]

        # One accumulator per channel; each seed adds an exponentially
        # decaying blob with a channel-specific falloff.
        red = np.zeros((self.size, self.size))
        green = np.zeros((self.size, self.size))
        blue = np.zeros((self.size, self.size))
        for seed_x, seed_y in seeds:
            dist = np.sqrt((cols - seed_x)**2 + (rows - seed_y)**2)
            red += np.exp(-dist * 0.05)
            green += np.exp(-dist * 0.03) * np.sin(seed_x)
            blue += np.exp(-dist * 0.04)

        blobs = np.dstack((red, green, blue))
        return np.clip(blobs, 0, 1).astype(np.float32)

    def generate_fractal_percussive(self):
        """Starburst pattern fading with radius, for sharp percussive bursts."""
        axis = np.linspace(-1, 1, self.size)
        col = axis[np.newaxis, :]
        row = axis[:, np.newaxis]
        radius = np.sqrt(col**2 + row**2)
        angle = np.arctan2(row, col)

        # Starburst: 20 angular lobes attenuated away from the centre.
        burst = np.sin(angle * 20) * 0.5 + 0.5
        burst = burst * np.exp(-radius * 2)

        sparkle = np.random.rand(*burst.shape) * burst
        return np.dstack((burst, 1 - burst, sparkle)).astype(np.float32)

class SynesthesiaDataset(Dataset):
    """Endless synthetic pairs of (spectrogram, target image).

    Each item is generated on the fly: a random audio family is picked and
    paired with the visual family assigned to it. ``idx`` is ignored.
    """

    def __init__(self, size=2000, img_size=256, sample_rate=22050):
        self.size = size
        self.audio_gen = AdvancedAudioGenerator(sample_rate)
        self.vis_gen = AdvancedVisualGenerator(img_size)

    def __len__(self):
        return self.size

    def __getitem__(self, idx):
        """Return ``(spec, target_img)``; the image is scaled to [-1, 1]."""
        family = random.choice(["sine", "noise", "chirp", "fm", "percussive"])

        # Audio family -> (audio factory, matching image factory). The audio
        # factory always runs first so random draws happen in a fixed order.
        recipes = {
            "sine": (
                self.audio_gen.generate_sine,
                lambda: self.vis_gen.generate_gradient(random.random()),
            ),
            "noise": (
                self.audio_gen.generate_white_noise,
                self.vis_gen.generate_noise_texture,
            ),
            "chirp": (
                self.audio_gen.generate_chirp,
                self.vis_gen.generate_structured_pattern,
            ),
            "fm": (
                self.audio_gen.generate_fm,
                self.vis_gen.generate_voronoi,
            ),
            "percussive": (
                self.audio_gen.generate_percussive,
                self.vis_gen.generate_fractal_percussive,
            ),
        }
        make_audio, make_image = recipes[family]
        audio, _, _ = make_audio()
        target_img_np = make_image()

        # Audio -> Spectrogram -> Unified Normalize
        spec = normalizer.transform(torch.from_numpy(audio).unsqueeze(0))

        # Image (H, W, C) in [0, 1] -> Tensor (C, H, W) in [-1, 1]
        channels_first = torch.from_numpy(target_img_np).permute(2, 0, 1)
        target_img = channels_first.mul(2.0).sub(1.0)

        return spec, target_img


## 4. Pix2Pix Model

In [None]:

# [Same UNet and PatchGAN code as before to ensure self-contained notebook]
class UNetDown(nn.Module):
    """Encoder stage: stride-2 4x4 convolution that halves spatial size.

    The convolution is optionally followed by instance normalization, then a
    LeakyReLU(0.2), then optional dropout.
    """

    def __init__(self, in_channels, out_channels, normalize=True, dropout=0.0):
        super().__init__()
        conv = nn.Conv2d(in_channels, out_channels, 4, 2, 1, bias=False)
        norm = [nn.InstanceNorm2d(out_channels)] if normalize else []
        drop = [nn.Dropout(dropout)] if dropout else []
        # Layer order (and therefore state_dict keys) matters for checkpoints.
        self.model = nn.Sequential(conv, *norm, nn.LeakyReLU(0.2), *drop)

    def forward(self, x):
        return self.model(x)

class UNetUp(nn.Module):
    """Decoder stage: stride-2 transposed convolution that doubles spatial size.

    The upsampled features are concatenated with the matching encoder output
    (skip connection) along the channel dimension.
    """

    def __init__(self, in_channels, out_channels, dropout=0.0):
        super().__init__()
        stages = [
            nn.ConvTranspose2d(in_channels, out_channels, 4, 2, 1, bias=False),
            nn.InstanceNorm2d(out_channels),
            nn.ReLU(inplace=True),
        ]
        if dropout:
            stages += [nn.Dropout(dropout)]
        self.model = nn.Sequential(*stages)

    def forward(self, x, skip_input):
        upsampled = self.model(x)
        return torch.cat((upsampled, skip_input), 1)

class UNetGenerator(nn.Module):
    """Eight-level U-Net mapping a spectrogram to an image in [-1, 1].

    Eight ``UNetDown`` stages each halve the resolution, seven ``UNetUp``
    stages restore it while consuming the encoder outputs as skip
    connections, and a final upsample + convolution + Tanh produces the
    output channels.
    """

    def __init__(self, in_channels=1, out_channels=3):
        super().__init__()

        # (in_channels, out_channels, extra UNetDown options) per encoder level.
        encoder_plan = (
            (in_channels, 64, {"normalize": False}),
            (64, 128, {}),
            (128, 256, {}),
            (256, 512, {"dropout": 0.5}),
            (512, 512, {"dropout": 0.5}),
            (512, 512, {"dropout": 0.5}),
            (512, 512, {"dropout": 0.5}),
            (512, 512, {"normalize": False, "dropout": 0.5}),
        )
        # Attribute names down1..down8 are kept so checkpoint keys are unchanged.
        for level, (c_in, c_out, options) in enumerate(encoder_plan, start=1):
            setattr(self, f"down{level}", UNetDown(c_in, c_out, **options))

        # (in_channels, out_channels, dropout) per decoder level; inputs after
        # the first level are doubled by the concatenated skip connection.
        decoder_plan = (
            (512, 512, 0.5),
            (1024, 512, 0.5),
            (1024, 512, 0.5),
            (1024, 512, 0.5),
            (1024, 256, 0.0),
            (512, 128, 0.0),
            (256, 64, 0.0),
        )
        for level, (c_in, c_out, drop) in enumerate(decoder_plan, start=1):
            setattr(self, f"up{level}", UNetUp(c_in, c_out, dropout=drop))

        self.final = nn.Sequential(
            nn.Upsample(scale_factor=2),
            nn.ZeroPad2d((1, 0, 1, 0)),
            nn.Conv2d(128, out_channels, 4, padding=1),
            nn.Tanh(),
        )

    def forward(self, x):
        encoders = (self.down1, self.down2, self.down3, self.down4,
                    self.down5, self.down6, self.down7, self.down8)
        decoders = (self.up1, self.up2, self.up3, self.up4,
                    self.up5, self.up6, self.up7)

        # Run the encoder, remembering every level's output for the skips.
        skips = []
        features = x
        for encode in encoders:
            features = encode(features)
            skips.append(features)

        # The deepest output is the bottleneck; the remaining ones are consumed
        # deepest-first by the decoder.
        features = skips.pop()
        for decode in decoders:
            features = decode(features, skips.pop())

        return self.final(features)

class PatchGANDiscriminator(nn.Module):
    """Patch-level discriminator scoring a channel-wise concatenated pair.

    ``in_channels`` is the channel count of the second input; the first input
    is assumed to have 3 channels. The output is a map of unnormalized scores
    (logits), one per image patch.
    """

    def __init__(self, in_channels=1):
        super().__init__()

        stages = []
        # (in_features, out_features, use_instance_norm) for each stride-2 stage.
        plan = (
            (in_channels + 3, 64, False),
            (64, 128, True),
            (128, 256, True),
            (256, 512, True),
        )
        for n_in, n_out, use_norm in plan:
            stages.append(nn.Conv2d(n_in, n_out, 4, stride=2, padding=1))
            if use_norm:
                stages.append(nn.InstanceNorm2d(n_out))
            stages.append(nn.LeakyReLU(0.2, inplace=True))

        # Asymmetric padding followed by the 1-channel scoring convolution.
        stages.append(nn.ZeroPad2d((1, 0, 1, 0)))
        stages.append(nn.Conv2d(512, 1, 4, padding=1, bias=False))

        self.model = nn.Sequential(*stages)

    def forward(self, img_A, img_B):
        pair = torch.cat((img_A, img_B), 1)
        return self.model(pair)

def weights_init_normal(m):
    """Initialize a single module in place; meant for ``Module.apply``.

    Convolution and transposed-convolution weights are drawn from N(0, 0.02).
    BatchNorm2d weights are drawn from N(1, 0.02) and their biases set to 0.
    Any other module type is left untouched.
    """
    if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d)):
        torch.nn.init.normal_(m.weight.data, 0.0, 0.02)
        return
    if isinstance(m, nn.BatchNorm2d):
        torch.nn.init.normal_(m.weight.data, 1.0, 0.02)
        torch.nn.init.constant_(m.bias.data, 0.0)


## 5. Professional Training Loop (200 Epochs + Checkpoints)

In [None]:

# Helpers
def save_checkpoint(epoch, generator, discriminator, optimizer_G, optimizer_D):
    """Write model and optimizer states for ``epoch`` to CHECKPOINT_DIR.

    The file is written under a temporary name first and then renamed, so an
    interrupted save never leaves a truncated ``checkpoint_epoch_N.pth`` that
    would later be picked up as the latest checkpoint.

    Args:
        epoch: zero-based index of the epoch that just finished.
        generator, discriminator: modules whose ``state_dict`` is stored.
        optimizer_G, optimizer_D: optimizers whose ``state_dict`` is stored.
    """
    path = os.path.join(CHECKPOINT_DIR, f"checkpoint_epoch_{epoch}.pth")
    # The temporary name deliberately does not start with "checkpoint" so a
    # leftover file is never mistaken for a finished checkpoint.
    tmp_path = os.path.join(CHECKPOINT_DIR, f".tmp_checkpoint_epoch_{epoch}.pth")
    state = {
        'epoch': epoch,
        'generator_state_dict': generator.state_dict(),
        'discriminator_state_dict': discriminator.state_dict(),
        'optimizer_G_state_dict': optimizer_G.state_dict(),
        'optimizer_D_state_dict': optimizer_D.state_dict(),
    }
    try:
        torch.save(state, tmp_path)
        os.replace(tmp_path, path)
    except BaseException:
        # Do not leave a partial temporary file behind, then re-raise.
        if os.path.exists(tmp_path):
            os.remove(tmp_path)
        raise
    print(f"Checkpoint saved: {path}")

def load_checkpoint(generator, discriminator, optimizer_G, optimizer_D):
    """Restore the most recent checkpoint from CHECKPOINT_DIR, if any.

    Only files named ``checkpoint_epoch_<N>.pth`` are considered, so unrelated
    files that merely start with "checkpoint" no longer break the numeric sort.

    Args:
        generator, discriminator: modules to load weights into (in place).
        optimizer_G, optimizer_D: optimizers to load state into (in place).

    Returns:
        The epoch index to resume from: 0 when no checkpoint exists, otherwise
        the saved epoch plus one.
    """
    import re

    name_pattern = re.compile(r"checkpoint_epoch_(\d+)\.pth")
    candidates = []
    for fname in os.listdir(CHECKPOINT_DIR):
        match = name_pattern.fullmatch(fname)
        if match:
            candidates.append((int(match.group(1)), fname))
    if not candidates:
        return 0

    # Highest epoch number wins.
    _, latest = max(candidates)

    checkpoint = torch.load(os.path.join(CHECKPOINT_DIR, latest), map_location=DEVICE)
    generator.load_state_dict(checkpoint['generator_state_dict'])
    discriminator.load_state_dict(checkpoint['discriminator_state_dict'])
    optimizer_G.load_state_dict(checkpoint['optimizer_G_state_dict'])
    optimizer_D.load_state_dict(checkpoint['optimizer_D_state_dict'])
    print(f"Resumed from epoch {checkpoint['epoch']}")
    return checkpoint['epoch'] + 1

# Setup
from contextlib import nullcontext  # no-op stand-in for autocast when AMP is off

generator = UNetGenerator().to(DEVICE)
discriminator = PatchGANDiscriminator().to(DEVICE)
generator.apply(weights_init_normal)
discriminator.apply(weights_init_normal)

criterion_GAN = nn.BCEWithLogitsLoss().to(DEVICE)
criterion_pixelwise = nn.L1Loss().to(DEVICE)

optimizer_G = optim.Adam(generator.parameters(), lr=LR, betas=(0.5, 0.999))
optimizer_D = optim.Adam(discriminator.parameters(), lr=LR, betas=(0.5, 0.999))

scheduler_G = optim.lr_scheduler.CosineAnnealingLR(optimizer_G, T_max=EPOCHS)
scheduler_D = optim.lr_scheduler.CosineAnnealingLR(optimizer_D, T_max=EPOCHS)

dataloader = DataLoader(SynesthesiaDataset(size=4000), batch_size=BATCH_SIZE, shuffle=True, num_workers=2)

# Mixed precision is only used when a GPU is present; scaler is None otherwise.
if torch.cuda.is_available():
    scaler = torch.amp.GradScaler('cuda')
else:
    scaler = None


def _amp_context():
    """Return the autocast context when AMP is active, else a no-op context."""
    return torch.amp.autocast('cuda') if scaler is not None else nullcontext()


def _backward_and_step(loss, optimizer):
    """Backpropagate ``loss`` and step ``optimizer``, via the scaler if AMP is on."""
    if scaler is not None:
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()
    else:
        loss.backward()
        optimizer.step()


# Resume
START_EPOCH = load_checkpoint(generator, discriminator, optimizer_G, optimizer_D)
if START_EPOCH > 0:
    # The optimizers come back with the learning rate they had when saved, but
    # the freshly built schedulers still think they are at epoch 0. Align their
    # epoch counters so the cosine schedule continues instead of restarting.
    scheduler_G.last_epoch = START_EPOCH
    scheduler_D.last_epoch = START_EPOCH

# Training Loop
for epoch in range(START_EPOCH, EPOCHS):
    for i, (spec, target_img) in enumerate(dataloader):
        real_a = spec.to(DEVICE)        # conditioning spectrogram
        real_b = target_img.to(DEVICE)  # ground-truth image

        # Train G: fool the discriminator while staying close to the target.
        optimizer_G.zero_grad()
        with _amp_context():
            fake_b = generator(real_a)
            pred_fake = discriminator(fake_b, real_a)
            loss_GAN = criterion_GAN(pred_fake, torch.ones_like(pred_fake))
            loss_pixel = criterion_pixelwise(fake_b, real_b)
            loss_G = loss_GAN + LAMBDA_L1 * loss_pixel
        _backward_and_step(loss_G, optimizer_G)

        # Train D: real pairs -> 1, generated pairs -> 0.
        # zero_grad also discards the gradients the G step left on D.
        optimizer_D.zero_grad()
        with _amp_context():
            pred_real = discriminator(real_b, real_a)
            loss_real = criterion_GAN(pred_real, torch.ones_like(pred_real))
            # detach: the D update must not backpropagate into the generator.
            pred_fake = discriminator(fake_b.detach(), real_a)
            loss_fake = criterion_GAN(pred_fake, torch.zeros_like(pred_fake))
            loss_D = 0.5 * (loss_real + loss_fake)
        _backward_and_step(loss_D, optimizer_D)

        if i % 100 == 0:
            print(f"[Epoch {epoch}] [Batch {i}] [Loss D: {loss_D.item():.4f}] [Loss G: {loss_G.item():.4f}]")

    scheduler_G.step()
    scheduler_D.step()

    # Save and Show
    if (epoch + 1) % 10 == 0:
        save_checkpoint(epoch, generator, discriminator, optimizer_G, optimizer_D)

    # Visualize up to 3 rows: input spectrogram | target | generated
    if (epoch + 1) % 5 == 0:
        generator.eval()
        with torch.no_grad():
            # The last batch may hold fewer than 3 samples.
            n_show = min(3, real_a.size(0))
            fakes = generator(real_a[:n_show])
            plt.figure(figsize=(9, 9))
            for k in range(n_show):
                plt.subplot(3, 3, k*3+1)
                plt.imshow(real_a[k].cpu().squeeze(), cmap='magma', origin='lower'); plt.axis('off')
                plt.subplot(3, 3, k*3+2)
                plt.imshow((real_b[k].cpu().permute(1,2,0) * 0.5 + 0.5).clamp(0, 1)); plt.axis('off')
                plt.subplot(3, 3, k*3+3)
                plt.imshow((fakes[k].cpu().permute(1,2,0) * 0.5 + 0.5).clamp(0, 1)); plt.axis('off')
            plt.show()
        generator.train()

## 6. Kinetic 3D Experience (Final Demo)

In [None]:

def generate_kinetic_video(audio_path, output_path="kinetic_synesthesia.mp4", fps=30):
    """Render an audio file as a video of generated art plus a 3D height map.

    A sliding window moves over the audio; each window is converted to a
    spectrogram, passed through the global ``generator`` and drawn as one
    frame (left: generated image, right: its luminance as a rotating surface).
    The frames are then encoded together with the original audio.

    Args:
        audio_path: path of the audio file to visualize.
        output_path: path of the video file to write.
        fps: frames per second of the output video.

    Returns:
        None. Nothing is written when the audio is shorter than one window.
    """
    print(f"Rendering Kinetic 3D Experience for {audio_path}...")
    waveform, sr = torchaudio.load(audio_path)
    if sr != SAMPLE_RATE:
        waveform = T.Resample(sr, SAMPLE_RATE)(waveform)
    # The generator takes a single-channel spectrogram, so downmix
    # multi-channel audio to mono.
    if waveform.shape[0] > 1:
        waveform = waveform.mean(dim=0, keepdim=True)

    window_duration = 2.0
    window_samples = int(window_duration * SAMPLE_RATE)
    # At least one sample per step, otherwise range() below would reject 0.
    step_samples = max(1, int(SAMPLE_RATE / fps))

    generator.eval()

    # Low-resolution grid for the 3D surface (same for every frame).
    X, Y = np.meshgrid(np.linspace(0, 1, 64), np.linspace(0, 1, 64))

    temp_frames = []

    with torch.no_grad():
        # +1 so a window ending exactly at the last sample is included.
        for start in range(0, waveform.shape[1] - window_samples + 1, step_samples):
            chunk = waveform[:, start:start+window_samples]

            # 1. Spec & Gen
            spec = normalizer.transform(chunk)
            fake_tensor = generator(spec.unsqueeze(0).to(DEVICE))

            # 2. Get Data for Plotting: (C, H, W) in [-1, 1] -> (H, W, C) in [0, 1]
            img_np = (fake_tensor.squeeze().permute(1,2,0).cpu().numpy() * 0.5 + 0.5)
            img_np = np.clip(img_np, 0, 1)  # Keep float for plot

            # Grayscale Height Map (Low Res for 3D speed)
            h_map_small = cv2.resize(img_np, (64, 64))
            Z = np.dot(h_map_small[..., :3], [0.299, 0.587, 0.114])

            # 3. Render Frame with Matplotlib
            fig = plt.figure(figsize=(10, 5), dpi=100)

            # Left: Art
            ax2d = fig.add_subplot(1, 2, 1)
            ax2d.imshow(img_np)
            ax2d.set_title("Synesthesia Art")
            ax2d.axis('off')

            # Right: 3D Pulse
            ax3d = fig.add_subplot(1, 2, 2, projection='3d')
            ax3d.plot_surface(X, Y, Z, cmap='magma', linewidth=0, antialiased=False)
            ax3d.set_zlim(0, 1)
            ax3d.view_init(elev=45, azim=start/1000)  # Simple rotation effect
            ax3d.set_title("Kinetic Topography")
            ax3d.axis('off')

            # Grab the rendered pixels. buffer_rgba replaces tostring_rgb,
            # which recent Matplotlib releases no longer provide. Copy because
            # the buffer belongs to the figure that is closed right after.
            fig.canvas.draw()
            rgba = np.asarray(fig.canvas.buffer_rgba())
            temp_frames.append(rgba[..., :3].copy())

            plt.close(fig)
            if len(temp_frames) % 50 == 0:
                print(f"Rendered {len(temp_frames)} frames...")

    if not temp_frames:
        print("Audio is shorter than one analysis window; no video written.")
        return

    print("Encoding video...")
    clip = ImageSequenceClip(temp_frames, fps=fps)
    audio_clip = AudioFileClip(audio_path)
    try:
        # Trim the audio to the length actually covered by the frames.
        final = clip.set_audio(audio_clip.subclip(0, len(temp_frames)/fps))
        final.write_videofile(output_path, codec="libx264", audio_codec="aac")
    finally:
        # Release the file handles / subprocesses held by the clips.
        audio_clip.close()
        clip.close()

    from IPython.display import Video, display
    display(Video(output_path, embed=True))

# Create a test file and run
import scipy.io.wavfile as wav

# Generate at the same rate the file is written with, so the WAV header and
# the samples agree even if SAMPLE_RATE is changed from its default.
gen = AdvancedAudioGenerator(sample_rate=SAMPLE_RATE, duration=5.0)
audio, _, _ = gen.generate_fm()
wav.write("kinetics_test.wav", SAMPLE_RATE, audio)

generate_kinetic_video("kinetics_test.wav")
