# Artificial Synesthesia: Audio-to-Visual Generation with Pix2Pix

## 1. Setup & Config

In [None]:

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import torchaudio
import torchaudio.transforms as T
import numpy as np
import matplotlib.pyplot as plt
import librosa
import random
import os
import cv2
from PIL import Image
from moviepy.editor import VideoFileClip, AudioFileClip, ImageSequenceClip
import plotly.graph_objects as go

# ---- Runtime configuration ----
# Use the GPU whenever PyTorch can see one.
DEVICE = "cpu"
if torch.cuda.is_available():
    DEVICE = "cuda"
print(f"Using device: {DEVICE}")

# ---- Image geometry ----
IMG_SIZE = 256  # side length (pixels) of the square spectrogram / image tensors

# ---- Audio / STFT parameters ----
SAMPLE_RATE = 22050  # Hz
DURATION = 1.0       # seconds; NOTE(review): not referenced by the visible code, which uses its own 1.0 s defaults
N_FFT = 2048         # FFT / window length in samples
HOP_LENGTH = 512     # STFT hop in samples
N_MELS = 256         # NOTE(review): not referenced by the visible code; the dataset passes n_mels=img_size


## 2. Synthetic Dataset Generator

In [None]:

class SyntheticAudioGenerator:
    """Generate short synthetic mono test signals: sine, white noise, linear chirp.

    Every ``generate_*`` method returns ``(audio, label, params)`` where ``audio``
    is a float32 array of ``n_samples`` samples, ``label`` names the signal family
    and ``params`` records the (possibly random) parameters that were used.
    """

    def __init__(self, sample_rate=22050, duration=1.0):
        self.sr = sample_rate
        self.duration = duration
        # round() rather than int(): int(100 * 0.29) == 28 because of float error.
        self.n_samples = int(round(sample_rate * duration))

    def _time_axis(self):
        """Return the sample times ``k / sr`` (seconds) for ``k = 0 .. n_samples - 1``."""
        # Exactly 1/sr spacing. np.linspace(0, duration, n) includes the end point and
        # spaces samples by duration/(n-1), which slightly raises every frequency.
        return np.arange(self.n_samples) / self.sr

    def generate_sine(self, freq=None):
        """Unit-amplitude sine at ``freq`` Hz (uniform random in [200, 1000] Hz if None)."""
        if freq is None:
            freq = random.uniform(200, 1000)
        t = self._time_axis()
        audio = np.sin(2 * np.pi * freq * t)
        return audio.astype(np.float32), "sine", freq

    def generate_white_noise(self):
        """Gaussian white noise with standard deviation 0.5; params is 0."""
        audio = np.random.normal(0, 0.5, self.n_samples)
        return audio.astype(np.float32), "noise", 0

    def generate_chirp(self):
        """Linear chirp from f0 in [100, 400] Hz towards f1 in [800, 1500] Hz; params is (f0, f1)."""
        t = self._time_axis()
        f0 = random.uniform(100, 400)
        f1 = random.uniform(800, 1500)
        # Instantaneous frequency f0 + k*t reaches f1 at t = duration.
        k = (f1 - f0) / self.duration
        # Phase is the time integral of the instantaneous frequency.
        audio = np.sin(2 * np.pi * (f0 * t + 0.5 * k * t**2))
        return audio.astype(np.float32), "chirp", (f0, f1)

class VisualArtGenerator:
    """Procedurally generate square RGB target images.

    All ``generate_*`` methods return float32 arrays of shape ``(size, size, 3)``
    with values in [0, 1].
    """

    def __init__(self, img_size=256):
        self.size = img_size

    def normalized_to_img(self, data):
        """Convert a [0, 1] float image, (H, W) or (H, W, 3), to uint8 RGB.

        Grayscale input is replicated to three channels. Values are clipped to
        [0, 1] first, because casting out-of-range floats directly to uint8 does
        not saturate (e.g. 1.5 * 255 would not become 255).
        """
        data = np.asarray(data)
        if len(data.shape) == 2:
            data = np.stack([data] * 3, axis=-1)
        return (np.clip(data, 0.0, 1.0) * 255).astype(np.uint8)

    def generate_gradient(self, color_phase=0.0):
        """Vertical gradient: red rises and blue falls from row 0 to the last row;
        green is a sine of the row position shifted by ``color_phase`` radians."""
        x = np.linspace(0, 1, self.size)
        y = np.linspace(0, 1, self.size)
        _, Y = np.meshgrid(x, y)  # Y varies along axis 0 (rows)
        R = Y
        G = np.sin(Y * np.pi + color_phase) * 0.5 + 0.5
        B = 1.0 - Y
        img = np.stack([R, G, B], axis=-1)
        return img.astype(np.float32)

    def generate_noise_texture(self):
        """Contrast-stretched uniform noise with the green channel halved."""
        noise = np.random.uniform(0, 1, (self.size, self.size, 3))
        # Stretch x2 around 0.5, then clip: values below 0.25 / above 0.75 saturate.
        noise = np.clip((noise - 0.5) * 2.0 + 0.5, 0, 1)
        noise[..., 1] *= 0.5
        return noise.astype(np.float32)

    def generate_structured_pattern(self):
        """Diagonal sine stripes: R = G = Z and B = 1 - Z with Z = sin(x + y) in [0, 1]."""
        x = np.linspace(0, 10, self.size)
        y = np.linspace(0, 10, self.size)
        X, Y = np.meshgrid(x, y)
        Z = np.sin(X + Y) * 0.5 + 0.5
        img = np.stack([Z, Z, 1 - Z], axis=-1)
        return img.astype(np.float32)

class SynesthesiaDataset(Dataset):
    """On-the-fly paired dataset of (mel spectrogram, target image) examples.

    Each item draws a random signal family and pairs it with a fixed visual style:
    sine -> colour gradient, white noise -> noise texture, chirp -> diagonal stripes.
    ``idx`` is ignored: every access generates a fresh random pair, and ``size``
    only sets the nominal epoch length.

    Items:
        spec: float tensor (1, img_size, img_size), dB-scaled and clamped to [-1, 1].
        target_img: float tensor (3, img_size, img_size) scaled to [-1, 1].
    """

    def __init__(self, size=1000, img_size=256, sample_rate=22050):
        self.size = size
        self.img_size = img_size
        self.audio_gen = SyntheticAudioGenerator(sample_rate)
        self.vis_gen = VisualArtGenerator(img_size)

        # n_mels == img_size, so the mel axis already has the target height.
        self.mel_transform = T.MelSpectrogram(
            sample_rate=sample_rate,
            n_fft=N_FFT,
            win_length=N_FFT,
            hop_length=HOP_LENGTH,
            n_mels=img_size,
            power=2.0,
        )
        # Power -> dB conversion; built once instead of on every __getitem__.
        self.to_db = T.AmplitudeToDB()

    def __len__(self):
        return self.size

    def __getitem__(self, idx):
        choice = random.choice(["sine", "noise", "chirp"])
        if choice == "sine":
            audio, _, _ = self.audio_gen.generate_sine()
            target_img_np = self.vis_gen.generate_gradient(color_phase=random.random())
        elif choice == "noise":
            audio, _, _ = self.audio_gen.generate_white_noise()
            target_img_np = self.vis_gen.generate_noise_texture()
        else:
            audio, _, _ = self.audio_gen.generate_chirp()
            target_img_np = self.vis_gen.generate_structured_pattern()

        audio_tensor = torch.from_numpy(audio).unsqueeze(0)  # (1, n_samples)
        spec = self.to_db(self.mel_transform(audio_tensor))   # (1, n_mels, frames)
        # Map -80 dB -> -1 and 0 dB -> +1; anything outside is clamped.
        spec = torch.clamp((spec + 40) / 40, -1, 1)
        # Resize to the same square size as the target image (previously hard-coded
        # to 256, which broke any other img_size).
        spec = torch.nn.functional.interpolate(
            spec.unsqueeze(0),
            size=(self.img_size, self.img_size),
            mode='bilinear',
            align_corners=False,
        ).squeeze(0)

        # HWC in [0, 1] -> CHW in [-1, 1], matching the generator's tanh output range.
        target_img = torch.from_numpy(target_img_np).permute(2, 0, 1)
        target_img = (target_img * 2.0) - 1.0

        return spec, target_img


## 3. Model Architecture (U-Net & PatchGAN)

In [None]:

class UNetDown(nn.Module):
    """U-Net encoder stage.

    Bias-free 4x4 stride-2 convolution (halves even spatial sizes), then an
    optional InstanceNorm, LeakyReLU(0.2) and optional dropout.
    """

    def __init__(self, in_channels, out_channels, normalize=True, dropout=0.0):
        super().__init__()
        stages = [nn.Conv2d(in_channels, out_channels, 4, 2, 1, bias=False)]
        stages += [nn.InstanceNorm2d(out_channels)] if normalize else []
        stages.append(nn.LeakyReLU(0.2))
        stages += [nn.Dropout(dropout)] if dropout else []
        self.model = nn.Sequential(*stages)

    def forward(self, x):
        """Run the stage on ``x``."""
        out = self.model(x)
        return out

class UNetUp(nn.Module):
    """U-Net decoder stage.

    Bias-free 4x4 stride-2 transposed convolution (doubles the spatial size),
    InstanceNorm, ReLU and optional dropout; the result is then concatenated with
    the encoder skip tensor along the channel axis.
    """

    def __init__(self, in_channels, out_channels, dropout=0.0):
        super().__init__()
        upsample = nn.ConvTranspose2d(in_channels, out_channels, 4, 2, 1, bias=False)
        stages = [upsample, nn.InstanceNorm2d(out_channels), nn.ReLU(inplace=True)]
        if dropout:
            stages.append(nn.Dropout(dropout))
        self.model = nn.Sequential(*stages)

    def forward(self, x, skip_input):
        """Upsample ``x`` and append ``skip_input`` channel-wise.

        ``skip_input`` must have the upsampled spatial size; the output has
        ``out_channels + skip_input.shape[1]`` channels.
        """
        upsampled = self.model(x)
        return torch.cat([upsampled, skip_input], dim=1)

class UNetGenerator(nn.Module):
    """Pix2pix-style U-Net mapping an ``in_channels`` input to ``out_channels`` in [-1, 1].

    Eight stride-2 encoder stages reduce a 256x256 input to 1x1. Seven decoder
    stages climb back up, each concatenating the mirrored encoder activation.
    A final 2x upsample + 4x4 conv + tanh restores full resolution. Height and
    width must be multiples of 256 so that every skip connection lines up.
    """

    def __init__(self, in_channels=1, out_channels=3):
        super().__init__()
        # Encoder (spatial size for a 256x256 input in brackets).
        self.down1 = UNetDown(in_channels, 64, normalize=False)   # [128]
        self.down2 = UNetDown(64, 128)                            # [64]
        self.down3 = UNetDown(128, 256)                           # [32]
        self.down4 = UNetDown(256, 512, dropout=0.5)              # [16]
        self.down5 = UNetDown(512, 512, dropout=0.5)              # [8]
        self.down6 = UNetDown(512, 512, dropout=0.5)              # [4]
        self.down7 = UNetDown(512, 512, dropout=0.5)              # [2]
        self.down8 = UNetDown(512, 512, normalize=False, dropout=0.5)  # [1]

        # Decoder: in_channels = previous stage output + skip channels.
        self.up1 = UNetUp(512, 512, dropout=0.5)
        self.up2 = UNetUp(1024, 512, dropout=0.5)
        self.up3 = UNetUp(1024, 512, dropout=0.5)
        self.up4 = UNetUp(1024, 512, dropout=0.5)
        self.up5 = UNetUp(1024, 256)
        self.up6 = UNetUp(512, 128)
        self.up7 = UNetUp(256, 64)

        # 2x upsample; asymmetric pad + 4x4 conv (padding 1) keeps that size.
        self.final = nn.Sequential(
            nn.Upsample(scale_factor=2),
            nn.ZeroPad2d((1, 0, 1, 0)),
            nn.Conv2d(128, out_channels, 4, padding=1),
            nn.Tanh()
        )

    def forward(self, x):
        encoders = (self.down1, self.down2, self.down3, self.down4,
                    self.down5, self.down6, self.down7, self.down8)
        decoders = (self.up1, self.up2, self.up3, self.up4,
                    self.up5, self.up6, self.up7)

        # Encoder pass: keep every activation for the skip connections.
        skips = []
        h = x
        for down in encoders:
            h = down(h)
            skips.append(h)

        # Decoder pass: up1 pairs with down7, ..., up7 with down1. The bottleneck
        # (down8 output) is the starting point, not a skip.
        for up, skip in zip(decoders, reversed(skips[:-1])):
            h = up(h, skip)
        return self.final(h)

class PatchGANDiscriminator(nn.Module):
    """Conditional PatchGAN that scores local patches of (image, condition) pairs.

    ``forward(img_A, img_B)`` concatenates both inputs along channels, so their
    channel counts must sum to ``in_channels + 3``. The output is a 1-channel
    map of raw logits (no sigmoid) at 1/16 of the input resolution.
    """

    def __init__(self, in_channels=1):
        super().__init__()

        def conv_block(n_in, n_out, normalize=True):
            block = [nn.Conv2d(n_in, n_out, 4, stride=2, padding=1)]
            if normalize:
                block.append(nn.InstanceNorm2d(n_out))
            block.append(nn.LeakyReLU(0.2, inplace=True))
            return block

        widths = [in_channels + 3, 64, 128, 256, 512]
        layers = []
        for idx, (n_in, n_out) in enumerate(zip(widths[:-1], widths[1:])):
            # The first block sees raw inputs and is left unnormalised.
            layers.extend(conv_block(n_in, n_out, normalize=idx > 0))
        # Asymmetric pad + 4x4 conv (padding 1) keeps the spatial size of the logit map.
        layers.append(nn.ZeroPad2d((1, 0, 1, 0)))
        layers.append(nn.Conv2d(512, 1, 4, padding=1, bias=False))
        self.model = nn.Sequential(*layers)

    def forward(self, img_A, img_B):
        paired = torch.cat([img_A, img_B], dim=1)
        return self.model(paired)

def weights_init_normal(m):
    """Pix2pix-style initialisation, intended for ``model.apply(weights_init_normal)``.

    * Any module whose class name contains "Conv" (Conv*, ConvTranspose*):
      weight ~ N(0, 0.02).
    * BatchNorm* / InstanceNorm* layers that have learnable affine parameters:
      weight ~ N(1, 0.02), bias = 0.

    Norm layers without parameters are skipped. This includes ``nn.InstanceNorm2d``
    with its default ``affine=False``, which is the only norm used in this notebook.
    """
    classname = m.__class__.__name__
    if "Conv" in classname:
        torch.nn.init.normal_(m.weight.data, 0.0, 0.02)
    elif classname.startswith(("BatchNorm", "InstanceNorm")):
        # Affine params may be absent (weight/bias are None when affine=False).
        if getattr(m, "weight", None) is not None:
            torch.nn.init.normal_(m.weight.data, 1.0, 0.02)
        if getattr(m, "bias", None) is not None:
            torch.nn.init.constant_(m.bias.data, 0.0)


## 4. Training Loop

In [None]:

import contextlib

# Hyperparameters
LR = 0.0002
B1 = 0.5
B2 = 0.999
EPOCHS = 10
BATCH_SIZE = 16
LAMBDA_PIXEL = 100  # weight of the L1 reconstruction term relative to the adversarial term

# Initialize
generator = UNetGenerator().to(DEVICE)
discriminator = PatchGANDiscriminator().to(DEVICE)
generator.apply(weights_init_normal)
discriminator.apply(weights_init_normal)

criterion_GAN = nn.BCEWithLogitsLoss().to(DEVICE)  # the discriminator outputs raw logits
criterion_pixelwise = nn.L1Loss().to(DEVICE)

optimizer_G = optim.Adam(generator.parameters(), lr=LR, betas=(B1, B2))
optimizer_D = optim.Adam(discriminator.parameters(), lr=LR, betas=(B1, B2))

# Dataset
dataloader = DataLoader(
    SynesthesiaDataset(size=2000),
    batch_size=BATCH_SIZE,
    shuffle=True,
    num_workers=2
)

# Mixed precision only on CUDA. One GradScaler is shared by both optimizers, so
# scaler.update() runs once per iteration, after both optimizers have stepped.
use_amp = torch.cuda.is_available()
scaler = torch.amp.GradScaler('cuda') if use_amp else None


def _autocast():
    """Return a CUDA autocast context when AMP is on, otherwise a no-op context."""
    return torch.amp.autocast('cuda') if use_amp else contextlib.nullcontext()


def _backward_and_step(loss, optimizer):
    """Backpropagate ``loss`` and step ``optimizer``, scaling the loss under AMP."""
    if scaler is not None:
        scaler.scale(loss).backward()
        scaler.step(optimizer)
    else:
        loss.backward()
        optimizer.step()


# Training Loop
for epoch in range(EPOCHS):
    # The video cell calls generator.eval(); switch back so dropout is active
    # again if this cell is re-run afterwards.
    generator.train()
    discriminator.train()

    for i, (spec, target_img) in enumerate(dataloader):
        real_a = spec.to(DEVICE)        # condition: spectrogram (B, 1, 256, 256)
        real_b = target_img.to(DEVICE)  # target image (B, 3, 256, 256)

        # ---- Generator: fool D while staying close to the target in L1 ----
        optimizer_G.zero_grad()
        with _autocast():
            fake_b = generator(real_a)
            pred_fake = discriminator(fake_b, real_a)
            loss_GAN = criterion_GAN(pred_fake, torch.ones_like(pred_fake))
            loss_pixel = criterion_pixelwise(fake_b, real_b)
            loss_G = loss_GAN + LAMBDA_PIXEL * loss_pixel
        _backward_and_step(loss_G, optimizer_G)

        # ---- Discriminator: real pairs -> 1, generated pairs -> 0 ----
        # zero_grad also clears the gradients the generator pass left on D.
        optimizer_D.zero_grad()
        with _autocast():
            pred_real = discriminator(real_b, real_a)
            loss_real = criterion_GAN(pred_real, torch.ones_like(pred_real))
            pred_fake = discriminator(fake_b.detach(), real_a)
            loss_fake = criterion_GAN(pred_fake, torch.zeros_like(pred_fake))
            loss_D = 0.5 * (loss_real + loss_fake)
        _backward_and_step(loss_D, optimizer_D)

        if scaler is not None:
            scaler.update()

        if i % 100 == 0:
            print(f"[Epoch {epoch}/{EPOCHS}] [Batch {i}/{len(dataloader)}] [D loss: {loss_D.item():.4f}] [G loss: {loss_G.item():.4f}]")


## 5. Video Inference Pipeline

In [None]:

def generate_synesthesia_video(audio_path, output_path="synesthesia.mp4", fps=30, window_duration=2.0):
    """Render an .mp4 whose frames are generator outputs for a sliding audio window.

    One frame is produced per 1/fps seconds of audio. Frame ``k`` comes from the
    ``window_duration``-second window centred on time ``k / fps``. The audio is
    zero-padded by half a window at both ends, so the visuals line up with the
    sound being played and the video spans the whole file.

    Args:
        audio_path: Audio file readable by torchaudio and MoviePy.
        output_path: Destination .mp4 (libx264 video, AAC audio).
        fps: Output frame rate.
        window_duration: Seconds of audio per spectrogram. NOTE: training clips
            are 1 s long (SyntheticAudioGenerator's default duration). Longer
            windows squeeze more time into the same IMG_SIZE spectrogram columns.

    Returns:
        None. Writes ``output_path`` and displays it inline via IPython.
    """
    print(f"Processing video for: {audio_path}")

    # Load as (channels, samples) at the model's sample rate.
    waveform, sr = torchaudio.load(audio_path)
    if sr != SAMPLE_RATE:
        waveform = T.Resample(sr, SAMPLE_RATE)(waveform)
    # The generator takes a 1-channel spectrogram: downmix stereo/multichannel audio.
    waveform = waveform.mean(dim=0, keepdim=True)
    total_samples = waveform.shape[1]

    window_size = int(round(window_duration * SAMPLE_RATE))
    # Centre each window on its frame time by padding half a window of silence.
    left_pad = window_size // 2
    padded = torch.nn.functional.pad(waveform, (left_pad, window_size - left_pad))
    # One frame per 1/fps s of real audio; floor so the video never outruns the sound.
    n_frames = int(total_samples * fps / SAMPLE_RATE)

    mel_transform = T.MelSpectrogram(
        sample_rate=SAMPLE_RATE,
        n_fft=N_FFT,
        win_length=N_FFT,
        hop_length=HOP_LENGTH,
        n_mels=IMG_SIZE,
        power=2.0
    )
    to_db = T.AmplitudeToDB()

    frames = []
    was_training = generator.training
    generator.eval()

    print("Generating frames...")
    try:
        with torch.no_grad():
            for k in range(n_frames):
                # Derive each offset from the frame index, rather than accumulating
                # int(SAMPLE_RATE / fps), so fractional samples-per-frame don't drift.
                start = round(k * SAMPLE_RATE / fps)
                chunk = padded[:, start:start + window_size]

                # Same preprocessing as SynesthesiaDataset: dB mel spectrogram,
                # -80..0 dB mapped to -1..1 and clamped, then resized to the model input.
                spec = to_db(mel_transform(chunk))
                spec = torch.clamp((spec + 40) / 40, -1, 1)
                spec = torch.nn.functional.interpolate(
                    spec.unsqueeze(0), size=(IMG_SIZE, IMG_SIZE), mode='bilinear', align_corners=False
                )

                fake_art = generator(spec.to(DEVICE))  # (1, 3, H, W), tanh range [-1, 1]

                img_np = fake_art[0].permute(1, 2, 0).cpu().numpy() * 0.5 + 0.5
                frames.append(np.clip(img_np * 255, 0, 255).astype(np.uint8))
    finally:
        generator.train(was_training)

    if not frames:
        print("Audio too short for video generation.")
        return

    print("Compiling video...")
    clip = ImageSequenceClip(frames, fps=fps)
    audio_clip = AudioFileClip(audio_path)
    try:
        # Frame k is shown at k / fps and is centred on that audio time, so the
        # soundtrack starts at 0 with no offset.
        video_duration = min(len(frames) / fps, audio_clip.duration)
        final_clip = clip.set_audio(audio_clip.subclip(0, video_duration))
        final_clip.write_videofile(output_path, codec="libx264", audio_codec="aac")
    finally:
        # Close the clips so their underlying file readers are released.
        audio_clip.close()
        clip.close()
    print(f"Video saved to {output_path}")

    from IPython.display import Video, display
    display(Video(output_path, embed=True))

# Test Video (Create dummy long audio): a 5 s chirp rendered to a video.
import scipy.io.wavfile as wav

# Generate at SAMPLE_RATE explicitly so the samples match the rate written into
# the WAV header (the generator's own default is a separate hard-coded 22050).
gen = SyntheticAudioGenerator(sample_rate=SAMPLE_RATE, duration=5.0)
audio, _, _ = gen.generate_chirp()
wav.write("long_test.wav", SAMPLE_RATE, audio)

generate_synesthesia_video("long_test.wav", "synesthesia_demo.mp4")


## 6. 3D Visualization

In [None]:

def render_3d_synesthesia(generated_img_tensor):
    """Plot a generated image as a 3D surface whose height is the image's luma.

    Args:
        generated_img_tensor: Generator output in [-1, 1], shaped (3, H, W) or
            (N, 3, H, W). For a batch, only the first image (index 0) is rendered.

    Returns:
        The plotly ``Figure``, which is also shown with ``fig.show()``.
    """
    if generated_img_tensor.dim() == 4:
        # Index explicitly: squeeze(0) would leave an N > 1 batch 4-D and break permute.
        generated_img_tensor = generated_img_tensor[0]

    # [-1, 1] -> [0, 1], channels-last (H, W, 3).
    img_np = generated_img_tensor.detach().cpu().permute(1, 2, 0).numpy() * 0.5 + 0.5

    # Height map = luma with Rec. 601 weights (0.299 R + 0.587 G + 0.114 B).
    z_data = np.dot(img_np[..., :3], [0.299, 0.587, 0.114])

    # go.Surface colours by a scalar mapped through a colorscale, not by per-vertex
    # RGB, so the original colours are only approximated here with 'Magma'.
    fig = go.Figure(data=[go.Surface(z=z_data, colorscale='Magma')])

    fig.update_layout(
        title='3D Synesthesia Landscape',
        autosize=False,
        width=800,
        height=800,
        margin=dict(l=65, r=50, b=65, t=90),
        scene=dict(
            xaxis=dict(title='Time'),
            yaxis=dict(title='Freq'),
            zaxis=dict(title='Intensity'),
        )
    )

    fig.show()
    return fig

# Test 3D (using the first item of the last generated training batch):
# generator.eval()
# render_3d_synesthesia(fake_b[0])
print("3D Function Loaded. Call render_3d_synesthesia(tensor) to visualize.")
