# Artificial Synesthesia Phase 2 (Styled): Professional Generative Art

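> **Optional**: Upload a texture image (canvas, paper, etc.) named `style.jpg` to the working directory or to the root of your Google Drive (`MyDrive/style.jpg`); it is multiplied onto every training target image as a custom style. Without it, a procedural paper texture is used.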

## 1. Setup & Config

In [None]:

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import torchaudio
import torchaudio.transforms as T
import numpy as np
import matplotlib.pyplot as plt
import librosa
import random
import os
import cv2
from PIL import Image
from moviepy.editor import VideoFileClip, AudioFileClip, ImageSequenceClip
import plotly.graph_objects as go
from scipy.spatial import Voronoi, voronoi_plot_2d
from google.colab import drive
import shutil

# Mount Google Drive so checkpoints survive Colab runtime resets; fall back to a
# local directory if mounting (or creating the folder on Drive) fails.
try:
    drive.mount('/content/drive')
    CHECKPOINT_DIR = "/content/drive/MyDrive/Synesthesia_Checkpoints"
    os.makedirs(CHECKPOINT_DIR, exist_ok=True)
except Exception as err:  # not a bare except: KeyboardInterrupt/SystemExit must propagate
    print(f"Drive unavailable ({err}); saving checkpoints locally.")
    CHECKPOINT_DIR = "./checkpoints"
    os.makedirs(CHECKPOINT_DIR, exist_ok=True)

DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using device: {DEVICE}")

IMG_SIZE = 256
SAMPLE_RATE = 22050
DURATION = 1.0
N_FFT = 2048
HOP_LENGTH = 512
N_MELS = 256  # NOTE(review): SpectrogramNormalizer uses IMG_SIZE mel bins, not this constant
# Larger batches only on big GPUs. The CUDA check comes first because querying
# device properties raises on CPU-only runtimes.
if torch.cuda.is_available() and torch.cuda.get_device_properties(0).total_memory > 14e9:
    BATCH_SIZE = 32
else:
    BATCH_SIZE = 16
LR = 0.0001
LAMBDA_L1 = 150  # weight of the L1 reconstruction term relative to the GAN term
EPOCHS = 200
START_EPOCH = 0


## 2. Unified Normalization

In [None]:

class SpectrogramNormalizer:
    """Convert waveforms into fixed-size log-mel images scaled to [-1, 1].

    The dataset (training inputs) and the video renderer (inference inputs)
    both use this transform, so they see identically normalized spectrograms.
    """

    # (MelSpectrogram, AmplitudeToDB) pairs keyed by device string. All of the
    # transform's parameters are module constants, so the mel filterbank only
    # needs to be built once per device instead of on every call.
    _transforms = {}

    @staticmethod
    def _get_transforms(device):
        """Return the cached (mel, to_db) transforms for *device*, creating them on first use."""
        key = str(device)
        cached = SpectrogramNormalizer._transforms.get(key)
        if cached is None:
            mel = T.MelSpectrogram(
                sample_rate=SAMPLE_RATE,
                n_fft=N_FFT,
                win_length=N_FFT,
                hop_length=HOP_LENGTH,
                n_mels=IMG_SIZE,
                power=2.0,
            ).to(device)
            cached = (mel, T.AmplitudeToDB())
            SpectrogramNormalizer._transforms[key] = cached
        return cached

    @staticmethod
    def transform(waveform):
        """Return a (channels, IMG_SIZE, IMG_SIZE) normalized log-mel spectrogram.

        Args:
            waveform: audio tensor, either (samples,) or (channels, samples).

        Power-mel values are converted to dB, mapped with (dB + 40) / 40 and
        clamped to [-1, 1]. The result is then bilinearly resized so that the
        time axis matches the image size used everywhere else (IMG_SIZE).
        """
        mel, to_db = SpectrogramNormalizer._get_transforms(waveform.device)
        spec = to_db(mel(waveform))
        spec = torch.clamp((spec + 40) / 40, -1, 1)
        if spec.dim() == 2:
            # 1-D waveform -> (n_mels, frames); add a channel axis.
            spec = spec.unsqueeze(0)
        spec = torch.nn.functional.interpolate(
            spec.unsqueeze(0),  # interpolate needs a batch axis
            size=(IMG_SIZE, IMG_SIZE),
            mode='bilinear',
            align_corners=False,
        ).squeeze(0)
        return spec


normalizer = SpectrogramNormalizer()


## 3. Style-Injected Dataset

In [None]:

# --- STYLE INJECTOR ---
class StyleInjector:
    """Multiply a style texture onto generated target images.

    The texture is loaded once from ``texture_path`` (converted to RGB, resized
    to ``size`` x ``size`` and scaled to [0, 1]). If the file is missing or
    cannot be read, a procedural "paper" texture is generated instead.
    """

    def __init__(self, size=256, texture_path="texture.jpg"):
        self.size = size
        self.texture = self.load_or_generate_texture(texture_path)

    def load_or_generate_texture(self, path):
        """Return a (size, size, 3) float32 texture in [0, 1] read from *path*, or the procedural fallback."""
        if os.path.exists(path):
            print(f"Loading custom style texture from {path}")
            try:
                # `with` closes the underlying file handle once pixels are decoded.
                with Image.open(path) as src:
                    img = src.convert('RGB').resize((self.size, self.size))
                # Scale 8-bit pixel values to [0, 1].
                return np.array(img).astype(np.float32) / 255.0
            except Exception as e:  # best-effort: any decode failure falls back to paper
                print(f"Error loading texture: {e}")

        print("Using Procedural Paper Texture fallback.")
        return self.generate_paper_texture()

    def generate_paper_texture(self):
        """Return a grainy off-white (size, size, 3) float32 texture with short darker 'fiber' strokes."""
        noise = np.random.normal(0.95, 0.05, (self.size, self.size, 3))
        for _ in range(20):
            # Builtin ints: OpenCV's point parser does not reliably accept numpy scalars.
            x1, y1 = (int(v) for v in np.random.randint(0, self.size, 2))
            length = np.random.randint(5, 20)
            angle = np.radians(np.random.uniform(0, 360))
            x2 = int(x1 + length * np.cos(angle))
            y2 = int(y1 + length * np.sin(angle))
            cv2.line(noise, (x1, y1), (x2, y2), (0.8, 0.8, 0.7), 1)
        return np.clip(noise, 0, 1).astype(np.float32)

    def apply_style(self, generated_img):
        """Multiply-blend *generated_img* (RGB in [0, 1], same shape as the texture) with the texture."""
        return np.clip(generated_img * self.texture, 0, 1).astype(np.float32)

# --- AUDIO & VISUAL GENS ---
class AdvancedAudioGenerator:
    """Synthesize short mono test signals as float32 numpy arrays.

    Every ``generate_*`` method returns ``(audio, kind, params)``, where
    ``audio`` has ``int(sample_rate * duration)`` samples.
    """

    def __init__(self, sample_rate=22050, duration=1.0):
        self.sr = sample_rate
        self.duration = duration
        self.n_samples = int(sample_rate * duration)

    def _time_axis(self):
        """Return the sample times ``k / sr`` for k in [0, n_samples).

        ``np.linspace(0, duration, n)`` would include the endpoint and space
        samples ``duration / (n - 1)`` apart. That would raise every
        synthesized frequency by a factor of ``n / (n - 1)``.
        """
        return np.arange(self.n_samples) / self.sr

    def generate_sine(self, freq=None):
        """Pure tone at *freq* Hz (random in [200, 1000) when omitted)."""
        if freq is None:
            freq = random.uniform(200, 1000)
        t = self._time_axis()
        audio = np.sin(2 * np.pi * freq * t)
        return audio.astype(np.float32), "sine", freq

    def generate_white_noise(self):
        """Gaussian white noise with standard deviation 0.5."""
        audio = np.random.normal(0, 0.5, self.n_samples)
        return audio.astype(np.float32), "noise", 0

    def generate_chirp(self):
        """Linear sweep from f0 in [100, 400) Hz to f1 in [800, 1500) Hz over the clip."""
        t = self._time_axis()
        f0 = random.uniform(100, 400)
        f1 = random.uniform(800, 1500)
        k = (f1 - f0) / self.duration
        # Phase is the integral of the instantaneous frequency f0 + k*t.
        audio = np.sin(2 * np.pi * (f0 * t + 0.5 * k * t**2))
        return audio.astype(np.float32), "chirp", (f0, f1)

    def generate_fm(self):
        """Frequency-modulated tone: carrier cf, modulator mf, modulation index mi (all random)."""
        t = self._time_axis()
        cf = random.uniform(200, 800)
        mf = random.uniform(10, 100)
        mi = random.uniform(1, 10)
        audio = np.sin(2 * np.pi * cf * t + mi * np.sin(2 * np.pi * mf * t))
        return audio.astype(np.float32), "fm", (cf, mf)

    def generate_percussive(self):
        """Gaussian noise burst (std 0.8) with an exp(-10 t) amplitude decay."""
        t = self._time_axis()
        noise = np.random.normal(0, 0.8, self.n_samples)
        decay = np.exp(-10 * t)
        audio = noise * decay
        return audio.astype(np.float32), "percussive", 0

class AdvancedVisualGenerator:
    """Procedural RGB image generators.

    Every ``generate_*`` method returns a float32 array of shape
    (img_size, img_size, 3).
    """

    def __init__(self, img_size=256):
        self.size = img_size

    def _grid(self, lo, hi):
        """Return a meshgrid (cols, rows) of ``size`` evenly spaced points in [lo, hi] on each axis."""
        axis = np.linspace(lo, hi, self.size)
        return np.meshgrid(axis, axis)

    def generate_gradient(self, color_phase=0.0):
        """Vertical gradient: red rises and blue falls top to bottom; green is a phase-shifted sine."""
        _, rows = self._grid(0, 1)
        red = rows
        green = np.sin(rows * np.pi + color_phase) * 0.5 + 0.5
        blue = 1.0 - rows
        return np.stack([red, green, blue], axis=-1).astype(np.float32)

    def generate_noise_texture(self):
        """Uniform RGB noise with contrast boosted 2.5x around mid-grey, clipped to [0, 1]."""
        raw = np.random.uniform(0, 1, (self.size, self.size, 3))
        boosted = np.clip((raw - 0.5) * 2.5 + 0.5, 0, 1)
        return boosted.astype(np.float32)

    def generate_structured_pattern(self):
        """Diagonal sine stripes: the same wave in R and G, its complement in B."""
        cols, rows = self._grid(0, 10)
        wave = np.sin(cols + rows) * 0.5 + 0.5
        return np.stack([wave, wave, 1 - wave], axis=-1).astype(np.float32)

    def generate_voronoi(self):
        """Sum of exponential-falloff blobs around 20 random seed points (not true Voronoi cells).

        The green channel is weighted by sin(seed x), so it can receive
        negative contributions; the result is clipped to [0, 1].
        """
        seeds = np.random.rand(20, 2) * self.size
        pixel_idx = np.arange(self.size)
        cols, rows = np.meshgrid(pixel_idx, pixel_idx)
        canvas = np.zeros((self.size, self.size, 3))
        for sx, sy in seeds:
            dist = np.sqrt((cols - sx)**2 + (rows - sy)**2)
            canvas[:, :, 0] += np.exp(-dist * 0.05)
            canvas[:, :, 1] += np.exp(-dist * 0.03) * np.sin(sx)
            canvas[:, :, 2] += np.exp(-dist * 0.04)
        return np.clip(canvas, 0, 1).astype(np.float32)

    def generate_fractal_percussive(self):
        """20-spoke radial pattern fading with distance from the centre; blue channel is random speckle."""
        cols, rows = self._grid(-1, 1)
        radius = np.sqrt(cols**2 + rows**2)
        angle = np.arctan2(rows, cols)
        spokes = np.sin(angle * 20) * 0.5 + 0.5
        spokes *= np.exp(-radius * 2)
        speckle = np.random.rand(*spokes.shape) * spokes
        return np.stack([spokes, 1 - spokes, speckle], axis=-1).astype(np.float32)

class SynesthesiaDataset(Dataset):
    """Synthetic paired dataset of (log-mel spectrogram, styled target image).

    Every item is freshly random (``idx`` is ignored): one of five sound types
    is drawn, and a matching procedural image is generated and multiplied by
    the style texture.
    """

    def __init__(self, size=2000, img_size=256, sample_rate=22050, style_file="style.jpg"):
        self.size = size
        self.audio_gen = AdvancedAudioGenerator(sample_rate)
        self.vis_gen = AdvancedVisualGenerator(img_size)
        self.style_injector = StyleInjector(img_size, style_file)

    def __len__(self):
        return self.size

    def __getitem__(self, idx):
        """Return ``(spec, target_img)``; target_img is (3, H, W) rescaled to [-1, 1]."""
        sound_kind = random.choice(["sine", "noise", "chirp", "fm", "percussive"])

        # Each sound kind maps to its (audio synth, paired image synth).
        # The audio synth is always called before the image synth, so the
        # random-number draws happen in a fixed order.
        pairing = {
            "sine": (self.audio_gen.generate_sine,
                     lambda: self.vis_gen.generate_gradient(random.random())),
            "noise": (self.audio_gen.generate_white_noise, self.vis_gen.generate_noise_texture),
            "chirp": (self.audio_gen.generate_chirp, self.vis_gen.generate_structured_pattern),
            "fm": (self.audio_gen.generate_fm, self.vis_gen.generate_voronoi),
            "percussive": (self.audio_gen.generate_percussive, self.vis_gen.generate_fractal_percussive),
        }
        make_audio, make_image = pairing[sound_kind]
        audio = make_audio()[0]
        raw_img = make_image()

        styled = self.style_injector.apply_style(raw_img)

        spec = normalizer.transform(torch.from_numpy(audio).unsqueeze(0))

        # HWC in [0, 1] -> CHW in [-1, 1] to match the generator's tanh output range.
        target_img = torch.from_numpy(styled).permute(2, 0, 1) * 2.0 - 1.0

        return spec, target_img


## 4. Pix2Pix Model

In [None]:

class UNetDown(nn.Module):
    """U-Net encoder stage.

    A stride-2 4x4 convolution halves H and W. It is followed by optional
    InstanceNorm, LeakyReLU(0.2) and optional dropout.
    """

    def __init__(self, in_c, out_c, norm=True, drop=0.0):
        super().__init__()
        stage = [nn.Conv2d(in_c, out_c, 4, 2, 1, bias=False)]
        stage += [nn.InstanceNorm2d(out_c)] if norm else []
        stage.append(nn.LeakyReLU(0.2))
        stage += [nn.Dropout(drop)] if drop else []
        self.model = nn.Sequential(*stage)

    def forward(self, x):
        return self.model(x)

class UNetUp(nn.Module):
    """U-Net decoder stage.

    A stride-2 4x4 transposed convolution doubles H and W. It is followed by
    InstanceNorm, ReLU and optional dropout. The result is then concatenated
    channel-wise with the matching encoder skip tensor.
    """

    def __init__(self, in_c, out_c, drop=0.0):
        super().__init__()
        layers = [nn.ConvTranspose2d(in_c, out_c, 4, 2, 1, bias=False), nn.InstanceNorm2d(out_c), nn.ReLU(True)]
        if drop:
            layers.append(nn.Dropout(drop))
        self.model = nn.Sequential(*layers)

    def forward(self, x, skip):
        """Upsample *x* and return ``cat([x, skip], dim=1)``."""
        x = self.model(x)
        # Concatenation along channels only requires matching spatial dims.
        # Resize when they differ, e.g. odd sizes lost a pixel on the way down.
        if x.shape[2:] != skip.shape[2:]:
            x = torch.nn.functional.interpolate(x, size=skip.shape[2:], mode='bilinear', align_corners=False)
        return torch.cat((x, skip), 1)

class UNetGenerator(nn.Module):
    """Pix2pix-style U-Net mapping an ``in_c``-channel image to ``out_c`` channels in [-1, 1].

    The encoder has eight stride-2 stages, so the input's H and W must survive
    eight halvings; 256x256 reaches 1x1 at the bottleneck. The decoder mirrors
    it with seven skip-connected stages plus a final upsample + conv + tanh.
    """

    def __init__(self, in_c=1, out_c=3):
        super().__init__()
        # Encoder (attribute names and order define state_dict keys and init order).
        self.d1 = UNetDown(in_c, 64, norm=False)
        self.d2 = UNetDown(64, 128)
        self.d3 = UNetDown(128, 256)
        self.d4 = UNetDown(256, 512, drop=0.5)
        self.d5 = UNetDown(512, 512, drop=0.5)
        self.d6 = UNetDown(512, 512, drop=0.5)
        self.d7 = UNetDown(512, 512, drop=0.5)
        self.d8 = UNetDown(512, 512, norm=False, drop=0.5)
        # Decoder: inputs after u1 are doubled by the skip concatenation.
        self.u1 = UNetUp(512, 512, drop=0.5)
        self.u2 = UNetUp(1024, 512, drop=0.5)
        self.u3 = UNetUp(1024, 512, drop=0.5)
        self.u4 = UNetUp(1024, 512, drop=0.5)
        self.u5 = UNetUp(1024, 256)
        self.u6 = UNetUp(512, 128)
        self.u7 = UNetUp(256, 64)
        self.final = nn.Sequential(
            nn.Upsample(scale_factor=2),
            nn.ZeroPad2d((1, 0, 1, 0)),
            nn.Conv2d(128, out_c, 4, padding=1),
            nn.Tanh(),
        )

    def forward(self, x):
        encoder = (self.d1, self.d2, self.d3, self.d4, self.d5, self.d6, self.d7, self.d8)
        decoder = (self.u1, self.u2, self.u3, self.u4, self.u5, self.u6, self.u7)

        skips = []
        for down in encoder:
            x = down(x)
            skips.append(x)

        # The deepest activation (d8) is the bottleneck and has no skip partner.
        # Each decoder stage then consumes encoder outputs from deepest to shallowest.
        x = skips.pop()
        for up in decoder:
            x = up(x, skips.pop())
        return self.final(x)

class PatchGANDiscriminator(nn.Module):
    """Conditional PatchGAN.

    Scores overlapping patches of ``cat([image, condition], dim=1)``. The
    output is a map of raw logits (no sigmoid), one per patch.
    """

    def __init__(self, in_c=1):
        super().__init__()

        def stage(c_in, c_out, normalize=True):
            """Stride-2 4x4 conv, optional InstanceNorm, in-place LeakyReLU(0.2)."""
            parts = [nn.Conv2d(c_in, c_out, 4, 2, 1)]
            if normalize:
                parts.append(nn.InstanceNorm2d(c_out))
            parts.append(nn.LeakyReLU(0.2, True))
            return parts

        # First stage sees the 3-channel image stacked with the in_c-channel condition.
        widths = ((in_c + 3, 64, False), (64, 128, True), (128, 256, True), (256, 512, True))
        layers = []
        for c_in, c_out, normalize in widths:
            layers.extend(stage(c_in, c_out, normalize))
        layers += [nn.ZeroPad2d((1, 0, 1, 0)), nn.Conv2d(512, 1, 4, padding=1, bias=False)]
        self.model = nn.Sequential(*layers)

    def forward(self, a, b):
        return self.model(torch.cat((a, b), 1))

def weights_init_normal(m):
    """Pix2pix-style initialization, meant to be used via ``model.apply(weights_init_normal)``.

    Conv / transposed-conv layers get weights ~ N(0, 0.02) and zero bias
    (when they have one). Affine BatchNorm2d / InstanceNorm2d layers get
    weights ~ N(1, 0.02) and zero bias. Non-affine norms have no parameters
    and are skipped rather than crashing on ``weight is None``.
    """
    if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d)):
        torch.nn.init.normal_(m.weight.data, 0.0, 0.02)
        if m.bias is not None:
            torch.nn.init.constant_(m.bias.data, 0.0)
    elif isinstance(m, (nn.BatchNorm2d, nn.InstanceNorm2d)) and m.weight is not None:
        torch.nn.init.normal_(m.weight.data, 1.0, 0.02)
        if m.bias is not None:
            torch.nn.init.constant_(m.bias.data, 0.0)


## 5. Professional Training (Styled)

In [None]:

def save_checkpoint(epoch, generator, discriminator, optimizer_G, optimizer_D):
    """Save G/D weights and both optimizer states as ``checkpoint_epoch_{epoch}.pth`` in CHECKPOINT_DIR.

    The data is first written under a dot-prefixed temporary name and then
    atomically renamed. An interrupted save (e.g. a Colab disconnect while
    writing to Drive) therefore never leaves a truncated
    ``checkpoint_epoch_*.pth`` for resume to pick up.
    """
    path = os.path.join(CHECKPOINT_DIR, f"checkpoint_epoch_{epoch}.pth")
    tmp_path = os.path.join(CHECKPOINT_DIR, f".tmp_checkpoint_epoch_{epoch}.pth")
    state = {
        'epoch': epoch,
        'G': generator.state_dict(),
        'D': discriminator.state_dict(),
        'optG': optimizer_G.state_dict(),
        'optD': optimizer_D.state_dict(),
    }
    torch.save(state, tmp_path)
    os.replace(tmp_path, path)
    print(f"Saved: {path}")

def load_checkpoint(generator, discriminator, optimizer_G, optimizer_D):
    """Restore the highest-epoch ``checkpoint_epoch_<N>.pth`` from CHECKPOINT_DIR into the given objects.

    Only files matching that exact pattern are considered. Other files
    (temporary saves, notes, differently named backups) are ignored instead
    of crashing the epoch parse.

    Returns:
        The epoch to resume from (saved epoch + 1), or 0 when no checkpoint exists.
    """
    import re

    if not os.path.isdir(CHECKPOINT_DIR):
        return 0
    pattern = re.compile(r"checkpoint_epoch_(\d+)\.pth")
    candidates = []
    for name in os.listdir(CHECKPOINT_DIR):
        match = pattern.fullmatch(name)
        if match:
            candidates.append((int(match.group(1)), name))
    if not candidates:
        return 0

    _, latest = max(candidates)
    cp = torch.load(os.path.join(CHECKPOINT_DIR, latest), map_location=DEVICE)
    generator.load_state_dict(cp['G'])
    discriminator.load_state_dict(cp['D'])
    optimizer_G.load_state_dict(cp['optG'])
    optimizer_D.load_state_dict(cp['optD'])
    print(f"Resumed epoch {cp['epoch']}")
    return cp['epoch'] + 1

generator = UNetGenerator().to(DEVICE)
discriminator = PatchGANDiscriminator().to(DEVICE)
generator.apply(weights_init_normal)
discriminator.apply(weights_init_normal)
criterion_GAN = nn.BCEWithLogitsLoss().to(DEVICE)
criterion_L1 = nn.L1Loss().to(DEVICE)
optimizer_G = optim.Adam(generator.parameters(), lr=LR, betas=(0.5, 0.999))
optimizer_D = optim.Adam(discriminator.parameters(), lr=LR, betas=(0.5, 0.999))
scheduler_G = optim.lr_scheduler.CosineAnnealingLR(optimizer_G, T_max=EPOCHS)
scheduler_D = optim.lr_scheduler.CosineAnnealingLR(optimizer_D, T_max=EPOCHS)

# Attempt to look for 'style.jpg' in the current directory or Drive
style_path = "style.jpg"
if not os.path.exists(style_path) and os.path.exists("/content/drive/MyDrive/style.jpg"):
    style_path = "/content/drive/MyDrive/style.jpg"

print(f"Initializing Dataset with style check on: {style_path}")
dataloader = DataLoader(SynesthesiaDataset(size=4000, style_file=style_path), batch_size=BATCH_SIZE, shuffle=True, num_workers=2)

# Mixed precision only on CUDA. On CPU, autocast(enabled=False) is a plain no-op
# context, so gradients are still recorded (a no_grad() fallback would make
# .backward() fail).
USE_AMP = torch.cuda.is_available()
scaler = torch.amp.GradScaler('cuda') if USE_AMP else None

START_EPOCH = load_checkpoint(generator, discriminator, optimizer_G, optimizer_D)
# The optimizers were restored with their already-annealed lr, but the
# schedulers were not. Realign their epoch counters so cosine annealing
# continues from START_EPOCH instead of restarting. (No-op when START_EPOCH == 0.)
scheduler_G.last_epoch = START_EPOCH
scheduler_D.last_epoch = START_EPOCH

for epoch in range(START_EPOCH, EPOCHS):
    for i, (spec, real_b) in enumerate(dataloader):
        real_a, real_b = spec.to(DEVICE), real_b.to(DEVICE)

        # ---- Generator: fool D on (fake, condition) + stay close to the target in L1 ----
        optimizer_G.zero_grad()
        with torch.amp.autocast(DEVICE, enabled=USE_AMP):
            fake_b = generator(real_a)
            pred_fake = discriminator(fake_b, real_a)
            loss_G = criterion_GAN(pred_fake, torch.ones_like(pred_fake)) + LAMBDA_L1 * criterion_L1(fake_b, real_b)
        if scaler:
            scaler.scale(loss_G).backward()
            scaler.step(optimizer_G)
        else:
            loss_G.backward()
            optimizer_G.step()

        # ---- Discriminator: real pairs -> 1, generated pairs -> 0 ----
        # zero_grad also discards the D gradients produced by the generator's backward pass.
        optimizer_D.zero_grad()
        with torch.amp.autocast(DEVICE, enabled=USE_AMP):
            pred_real = discriminator(real_b, real_a)
            pred_fake = discriminator(fake_b.detach(), real_a)
            loss_D = 0.5 * (criterion_GAN(pred_real, torch.ones_like(pred_real)) + criterion_GAN(pred_fake, torch.zeros_like(pred_fake)))
        if scaler:
            scaler.scale(loss_D).backward()
            scaler.step(optimizer_D)
            # Exactly one scale update per iteration, after every optimizer has
            # stepped (AMP multi-optimizer recipe).
            scaler.update()
        else:
            loss_D.backward()
            optimizer_D.step()

        if i % 100 == 0:
            print(f"E{epoch} B{i} L_D:{loss_D.item():.3f} L_G:{loss_G.item():.3f}")

    scheduler_G.step()
    scheduler_D.step()
    if (epoch + 1) % 10 == 0:
        save_checkpoint(epoch, generator, discriminator, optimizer_G, optimizer_D)
    if (epoch + 1) % 5 == 0:
        # Preview up to 3 samples of the last batch: spectrogram | target | prediction.
        n_show = min(3, real_a.size(0))
        generator.eval()
        with torch.no_grad():
            preview = generator(real_a[:n_show])
        plt.figure(figsize=(9, 3 * n_show))
        for k in range(n_show):
            plt.subplot(n_show, 3, k * 3 + 1); plt.imshow(real_a[k].cpu().squeeze(), cmap='magma', origin='lower'); plt.axis('off')
            plt.subplot(n_show, 3, k * 3 + 2); plt.imshow(real_b[k].cpu().permute(1, 2, 0) * 0.5 + 0.5); plt.axis('off')
            plt.subplot(n_show, 3, k * 3 + 3); plt.imshow(preview[k].cpu().permute(1, 2, 0) * 0.5 + 0.5); plt.axis('off')
        plt.show()
        generator.train()


## 6. Kinetic 3D Experience

In [None]:

def _figure_to_rgb(fig):
    """Rasterize a Matplotlib figure and return its pixels as a contiguous (H, W, 3) uint8 array."""
    fig.canvas.draw()
    # buffer_rgba() replaces tostring_rgb() (removed in Matplotlib 3.10) and
    # already has the canvas' true pixel shape. Copy so the frame outlives the figure.
    rgba = np.asarray(fig.canvas.buffer_rgba())
    return np.ascontiguousarray(rgba[..., :3])


def _render_kinetic_frame(img_np, azim):
    """Draw one video frame: *img_np* beside a 3D surface of its luminance, viewed at *azim* degrees."""
    h_map = cv2.resize(img_np, (64, 64))
    Z = np.dot(h_map[..., :3], [0.299, 0.587, 0.114])  # Rec. 601 luma weights
    X, Y = np.meshgrid(np.linspace(0, 1, 64), np.linspace(0, 1, 64))

    fig = plt.figure(figsize=(10, 5), dpi=80)
    try:
        ax2d = fig.add_subplot(1, 2, 1)
        ax2d.imshow(img_np); ax2d.axis('off'); ax2d.set_title("Styled Art")
        ax3d = fig.add_subplot(1, 2, 2, projection='3d')
        ax3d.plot_surface(X, Y, Z, cmap='magma', linewidth=0, antialiased=False)
        ax3d.set_zlim(0, 1); ax3d.view_init(elev=45, azim=azim); ax3d.axis('off'); ax3d.set_title("Topography")
        return _figure_to_rgb(fig)
    finally:
        plt.close(fig)


def generate_kinetic_video(audio_path, output_path="kinetic_synesthesia.mp4", fps=30, window_seconds=None):
    """Render a video of generator output driven by a sliding window over an audio file.

    Args:
        audio_path: audio file readable by torchaudio. It is downmixed to mono
            and resampled to SAMPLE_RATE.
        output_path: destination MP4 (H.264 video + AAC audio from *audio_path*).
        fps: output frame rate; consecutive frames advance the window by SAMPLE_RATE / fps samples.
        window_seconds: analysis window length. Defaults to DURATION, the clip
            length the training spectrograms are built from, so inference
            spectrograms share the training time scale.

    The video is displayed inline when frames were rendered; if the audio is
    shorter than one window, a message is printed and nothing is written.
    The generator's train/eval mode is restored afterwards.
    """
    print(f"Rendering Kinetic 3D Experience for {audio_path}...")
    waveform, sr = torchaudio.load(audio_path)
    # The generator takes 1-channel spectrograms; a stereo file would otherwise yield 2.
    if waveform.shape[0] > 1:
        waveform = waveform.mean(dim=0, keepdim=True)
    if sr != SAMPLE_RATE:
        waveform = T.Resample(sr, SAMPLE_RATE)(waveform)

    if window_seconds is None:
        window_seconds = DURATION
    window_samples = int(window_seconds * SAMPLE_RATE)
    step_samples = max(1, int(SAMPLE_RATE / fps))
    last_start = waveform.shape[1] - window_samples

    was_training = generator.training
    generator.eval()
    temp_frames = []
    try:
        with torch.no_grad():
            # +1 so the final full window (start == last_start) is included.
            for start in range(0, last_start + 1, step_samples):
                chunk = waveform[:, start:start + window_samples]
                spec = normalizer.transform(chunk)
                fake_tensor = generator(spec.unsqueeze(0).to(DEVICE))
                img_np = np.clip(fake_tensor.squeeze(0).permute(1, 2, 0).cpu().numpy() * 0.5 + 0.5, 0, 1)
                temp_frames.append(_render_kinetic_frame(img_np, azim=start / 1000))
                if len(temp_frames) % 50 == 0:
                    print(f"Rendered {len(temp_frames)} frames")
    finally:
        generator.train(was_training)

    if not temp_frames:
        print(f"Audio is shorter than the {window_seconds:.2f}s analysis window; no video rendered.")
        return

    clip = ImageSequenceClip(temp_frames, fps=fps)
    audio = AudioFileClip(audio_path)
    try:
        # Never request more audio than the file contains (moviepy rejects over-long subclips).
        soundtrack = audio.subclip(0, min(len(temp_frames) / fps, audio.duration))
        clip.set_audio(soundtrack).write_videofile(output_path, codec="libx264", audio_codec="aac")
    finally:
        audio.close()
        clip.close()

    from IPython.display import Video, display
    display(Video(output_path, embed=True))

# Smoke test: synthesize a 5-second FM tone at the pipeline's sample rate and render it.
import scipy.io.wavfile as wav

# Pass SAMPLE_RATE explicitly so the synthesized signal matches the rate written to the WAV header.
gen = AdvancedAudioGenerator(sample_rate=SAMPLE_RATE, duration=5.0)
audio, _, _ = gen.generate_fm()
wav.write("kinetics_test.wav", SAMPLE_RATE, audio)
generate_kinetic_video("kinetics_test.wav")
