# 🎮 StackGAN Pokemon Generator - Pipeline Completa

**Questo notebook fa tutto automaticamente:**

1. ✅ Clona il repository

2. ✅ Installa le dipendenze

3. ✅ Training Stage-I (64x64)

4. ✅ Training Stage-II (215x215)

5. ✅ Generazione Pokemon da testo

6. ✅ Visualizzazione risultati



**Esegui le celle in ordine e aspetta che finiscano!**

In [None]:
# 🔧 INITIAL SETUP - fetch the repository files and install dependencies
#
# Brings the project files of BRANCH_NAME into the current working directory,
# makes the working directory importable, and pip-installs requirements.txt.
# The cell is safe to re-run: it updates the files whether or not the working
# directory is itself a git checkout.
import os
import sys
import subprocess
import shutil

print("🚀 AVVIO SETUP STACKGAN POKEMON GENERATOR")
print("=" * 50)

# Repository settings
GIT_REPO_URL = "https://github.com/Biobay/DeepLearning/"
BRANCH_NAME = "RIS-1-CORRETTO"
TEMP_DIR = "temp_clone"


def _copy_branch_into_cwd():
    """Clone BRANCH_NAME into TEMP_DIR and copy its files (minus .git) into the cwd.

    Raises:
        subprocess.CalledProcessError: if ``git clone`` exits with a non-zero status.
    """
    # A TEMP_DIR left behind by an interrupted run would make `git clone` fail.
    shutil.rmtree(TEMP_DIR, ignore_errors=True)
    try:
        subprocess.run(['git', 'clone', '--branch', BRANCH_NAME, GIT_REPO_URL, TEMP_DIR], check=True)
        print("📂 Copia dei file...")
        shutil.copytree(TEMP_DIR, ".", dirs_exist_ok=True, ignore=shutil.ignore_patterns('.git'))
    finally:
        # Always clean up, even when the clone or the copy failed.
        shutil.rmtree(TEMP_DIR, ignore_errors=True)


# Clone or update the repository
if not os.path.exists("scripts"):
    print(f"📥 Clonazione del branch '{BRANCH_NAME}'...")
    try:
        _copy_branch_into_cwd()
        print("✅ Repository clonato con successo!")
    except Exception as e:
        print(f"❌ Errore durante la clonazione: {e}")
        raise
else:
    print("🔄 Repository già presente! Aggiornamento all'ultimo commit...")
    try:
        if os.path.exists(".git"):
            # The working directory is a real checkout: update it in place.
            subprocess.run(['git', 'fetch', 'origin', BRANCH_NAME], check=True)
            subprocess.run(['git', 'checkout', BRANCH_NAME], check=True)
            subprocess.run(['git', 'pull', 'origin', BRANCH_NAME], check=True)
        else:
            # The files were copied here without .git (see _copy_branch_into_cwd),
            # so git commands cannot run in this directory; refresh by re-cloning.
            _copy_branch_into_cwd()
        print("✅ Repository aggiornato all'ultimo commit!")
    except Exception as e:
        print(f"❌ Errore durante l'aggiornamento: {e}")
        raise

# Make the project packages importable from the working directory
if os.getcwd() not in sys.path:
    sys.path.append(os.getcwd())
    print("✅ Path di sistema configurato")

# Install dependencies
print("📦 Installazione dipendenze...")
try:
    # pip output is captured so the notebook stays quiet; stderr is shown on failure.
    result = subprocess.run([sys.executable, '-m', 'pip', 'install', '-r', 'requirements.txt'],
                            check=True, capture_output=True, text=True)
    print("✅ Dipendenze installate con successo!")
except subprocess.CalledProcessError as e:
    print(f"❌ Errore installazione: {e.stderr}")
    raise

print("🎉 SETUP COMPLETATO! Procedi con la cella successiva")
print("=" * 50)

In [None]:
# 🏋️ TRAINING STAGE-I (64x64)
#
# Imports the project's Stage-I trainer and config, echoes the main
# hyper-parameters read from src.config, then calls train(config). Any failure
# is reported with a full traceback and re-raised so the notebook run stops here.
import traceback

print("🎯 AVVIO TRAINING STAGE-I (64x64)")
print("=" * 40)

try:
    from scripts.train import train
    import src.config as config

    print("📊 Configurazione:")
    # (label shown to the user, attribute read from src.config)
    for label, attr in (
        ("Epoche", "EPOCHS"),
        ("Device", "DEVICE"),
        ("Batch Size", "BATCH_SIZE"),
        ("Learning Rate", "LEARNING_RATE"),
    ):
        print(f"   - {label}: {getattr(config, attr)}")
    print()

    print("🚀 Inizio addestramento Stage-I...")
    train(config)
    print()
    for line in (
        "✅ TRAINING STAGE-I COMPLETATO!",
        "📁 Checkpoint salvato in: results/checkpoints/generator_s1.pth",
        "🖼️ Immagini generate in: results/generated_images/",
    ):
        print(line)

except Exception as err:
    print(f"❌ Errore durante il training Stage-I: {err}")
    traceback.print_exc()
    raise

In [None]:
# 📊 VIEW STAGE-I RESULTS
#
# Plots the per-epoch mean of the discriminator/generator losses logged in
# results/logs/loss_log.csv and, next to it, the most recent sample grid found
# in results/generated_images/. Errors are reported but not re-raised: this
# cell is only a visual check and must not block the rest of the notebook.
import glob
import os
import re

import pandas as pd
import matplotlib.pyplot as plt
from PIL import Image

print("📊 ANALISI RISULTATI STAGE-I")
print("=" * 35)


def _latest_by_epoch(paths):
    """Return the path whose file name carries the highest ``epoch_<N>`` number.

    Epochs are compared numerically, so ``epoch_100`` ranks after ``epoch_99``
    even when the numbers are not zero-padded (a plain string sort would get
    this wrong). Returns None when *paths* is empty.
    """
    def sort_key(path):
        match = re.search(r"epoch_(\d+)", os.path.basename(path))
        # Files without a numeric epoch sort first; the path breaks ties.
        return (int(match.group(1)) if match else -1, path)

    return max(paths, key=sort_key, default=None)


# Plot the losses
try:
    log_file = "results/logs/loss_log.csv"
    if os.path.exists(log_file):
        df = pd.read_csv(log_file)
        # numeric_only=True keeps the average working if the log ever contains
        # non-numeric columns (pandas >= 2 raises on them by default).
        df_epoch = df.groupby('epoch').mean(numeric_only=True)

        plt.figure(figsize=(12, 5))

        plt.subplot(1, 2, 1)
        plt.plot(df_epoch.index, df_epoch['loss_d'], 'r-', label='Discriminator', linewidth=2)
        plt.plot(df_epoch.index, df_epoch['loss_g'], 'b-', label='Generator', linewidth=2)
        plt.title('Loss Stage-I', fontsize=14, fontweight='bold')
        plt.xlabel('Epoch')
        plt.ylabel('Loss')
        plt.legend()
        plt.grid(True, alpha=0.3)

        # Show the generated images
        plt.subplot(1, 2, 2)
        latest_img = _latest_by_epoch(
            glob.glob("results/generated_images/generated_images_epoch_*.png")
        )
        if latest_img is not None:
            # Close the file handle as soon as matplotlib has taken the pixels.
            with Image.open(latest_img) as img:
                plt.imshow(img)
            plt.title('Immagini Generate Stage-I', fontsize=14, fontweight='bold')
            plt.axis('off')

        plt.tight_layout()
        plt.show()

        print(f"✅ Loss finale Discriminator: {df_epoch['loss_d'].iloc[-1]:.4f}")
        print(f"✅ Loss finale Generator: {df_epoch['loss_g'].iloc[-1]:.4f}")
    else:
        print("⚠️ File di log non trovato")

except Exception as e:
    print(f"❌ Errore visualizzazione: {e}")

print("\n🎯 Procedi con Stage-II!")

In [None]:
# 🚀 TRAINING STAGE-II (215x215)
#
# Imports the project's Stage-II trainer and config, refuses to start unless the
# Stage-I generator checkpoint file exists, echoes the Stage-II settings read
# from src.config, then calls train_stage2(config). Any failure is reported
# with a full traceback and re-raised so the notebook run stops here.
import traceback

print("🎯 AVVIO TRAINING STAGE-II (64x64 → 215x215)")
print("=" * 45)

try:
    from scripts.train_stage2 import train_stage2
    import src.config as config

    # Stage-II needs the Stage-I checkpoint to be on disk
    checkpoint_s1 = "results/checkpoints/generator_s1.pth"
    stage1_ready = os.path.exists(checkpoint_s1)
    if not stage1_ready:
        raise FileNotFoundError("❌ Checkpoint Stage-I non trovato! Esegui prima Stage-I.")

    print(f"✅ Checkpoint Stage-I trovato: {checkpoint_s1}")
    print("📊 Configurazione Stage-II:")
    # (label shown to the user, attribute read from src.config)
    for label, attr in (
        ("Epoche", "EPOCHS_S2"),
        ("Learning Rate", "LEARNING_RATE_S2"),
        ("Lambda L1", "LAMBDA_L1_S2"),
    ):
        print(f"   - {label}: {getattr(config, attr)}")
    print(f"   - Output: {config.STAGE2_IMAGE_SIZE}x{config.STAGE2_IMAGE_SIZE}")
    print()

    print("🚀 Inizio addestramento Stage-II...")
    train_stage2(config)
    print()
    print("✅ TRAINING STAGE-II COMPLETATO!")
    print("📁 Checkpoint salvato in: results/checkpoints/stage2/generator_s2.pth")
    print(
        f"🖼️ Immagini {config.STAGE2_IMAGE_SIZE}x{config.STAGE2_IMAGE_SIZE} "
        "generate in: results/generated_images/stage2/"
    )

except Exception as err:
    print(f"❌ Errore durante il training Stage-II: {err}")
    traceback.print_exc()
    raise

In [None]:
# 📊 VIEW STAGE-II RESULTS
#
# Plots the per-epoch mean of the Stage-II losses logged in
# results/logs/loss_log_stage2.csv and shows, side by side, the most recent
# Stage-I and Stage-II sample grids found in results/generated_images/stage2/.
# Errors are reported but not re-raised: this cell is only a visual check.
# The imports are repeated here so the cell does not depend on the Stage-I
# results cell having been run first.
import glob
import os
import re

import pandas as pd
import matplotlib.pyplot as plt
from PIL import Image

print("📊 ANALISI RISULTATI STAGE-II")
print("=" * 36)


def _latest_by_epoch(paths):
    """Return the path whose file name carries the highest ``epoch_<N>`` number.

    Epochs are compared numerically, so ``epoch_100`` ranks after ``epoch_99``
    even when the numbers are not zero-padded (a plain string sort would get
    this wrong). Returns None when *paths* is empty.
    """
    def sort_key(path):
        match = re.search(r"epoch_(\d+)", os.path.basename(path))
        # Files without a numeric epoch sort first; the path breaks ties.
        return (int(match.group(1)) if match else -1, path)

    return max(paths, key=sort_key, default=None)


def _show_latest_image(pattern, title):
    """Draw the newest image matching *pattern* on the current axes, if any exists."""
    latest = _latest_by_epoch(glob.glob(pattern))
    if latest is None:
        return
    # Close the file handle as soon as matplotlib has taken the pixels.
    with Image.open(latest) as img:
        plt.imshow(img)
    plt.title(title, fontsize=14, fontweight='bold')
    plt.axis('off')


try:
    # Plot the Stage-II losses
    log_file_s2 = "results/logs/loss_log_stage2.csv"
    if os.path.exists(log_file_s2):
        df_s2 = pd.read_csv(log_file_s2)
        # numeric_only=True keeps the average working if the log ever contains
        # non-numeric columns (pandas >= 2 raises on them by default).
        df_epoch_s2 = df_s2.groupby('epoch').mean(numeric_only=True)

        plt.figure(figsize=(15, 5))

        # Stage-II losses
        plt.subplot(1, 3, 1)
        plt.plot(df_epoch_s2.index, df_epoch_s2['loss_d'], 'r-', label='Discriminator', linewidth=2)
        plt.plot(df_epoch_s2.index, df_epoch_s2['loss_g'], 'b-', label='Generator Total', linewidth=2)
        plt.plot(df_epoch_s2.index, df_epoch_s2['loss_g_adv'], 'g--', label='Gen Adversarial', alpha=0.7)
        plt.plot(df_epoch_s2.index, df_epoch_s2['loss_g_l1'], 'orange', linestyle='--', label='Gen L1', alpha=0.7)
        plt.title('Loss Stage-II', fontsize=14, fontweight='bold')
        plt.xlabel('Epoch')
        plt.ylabel('Loss')
        plt.legend()
        plt.grid(True, alpha=0.3)

        # Stage-I images (64x64)
        plt.subplot(1, 3, 2)
        _show_latest_image(
            "results/generated_images/stage2/stage1_images_64_epoch_*.png",
            'Stage-I Output (64x64)',
        )

        # Stage-II images (215x215)
        plt.subplot(1, 3, 3)
        _show_latest_image(
            "results/generated_images/stage2/stage2_images_215_epoch_*.png",
            'Stage-II Output (215x215)',
        )

        plt.tight_layout()
        plt.show()

        print(f"✅ Loss finale Discriminator S2: {df_epoch_s2['loss_d'].iloc[-1]:.4f}")
        print(f"✅ Loss finale Generator S2: {df_epoch_s2['loss_g'].iloc[-1]:.4f}")
        print(f"✅ Loss finale L1 S2: {df_epoch_s2['loss_g_l1'].iloc[-1]:.4f}")
    else:
        print("⚠️ File di log Stage-II non trovato")

except Exception as e:
    print(f"❌ Errore visualizzazione Stage-II: {e}")

print("\n🎮 Ora puoi generare Pokemon personalizzati!")

In [None]:
# 🎮 CUSTOM POKEMON GENERATION
#
# Loads the text encoder and both generators from their checkpoints, encodes
# PROMPT, runs Stage-I then Stage-II once, shows both outputs and saves the
# Stage-II image. The names text_encoder, netG_s1, netG_s2, tokenizer and
# device are left in the notebook namespace for the multi-prompt cell.
# Errors are printed with a traceback but not re-raised.
import os

import torch
from src.models.encoder import TextEncoder
from src.models.decoder import GeneratorS1, GeneratorS2
import src.config as config
from torchvision.utils import save_image
import matplotlib.pyplot as plt
import numpy as np

print("🎮 GENERATORE POKEMON PERSONALIZZATO")
print("=" * 40)

# 🔥 CHANGE THIS TEXT TO GENERATE DIFFERENT POKEMON! 🔥
PROMPT = "a red fire dragon pokemon with wings and a long tail"
print(f"📝 Prompt: '{PROMPT}'")


def _to_displayable(img_tensor):
    """Convert a single generated image tensor into an HxWxC array for imshow.

    Applies ``x * 0.5 + 0.5`` and clips to [0, 1].
    NOTE(review): presumably the generators emit values in [-1, 1]; confirm.
    """
    arr = (img_tensor.squeeze().permute(1, 2, 0).cpu() * 0.5 + 0.5).numpy()
    return np.clip(arr, 0, 1)


try:
    device = torch.device(config.DEVICE)

    checkpoint_s1 = "results/checkpoints/generator_s1.pth"
    checkpoint_s2 = "results/checkpoints/stage2/generator_s2.pth"

    # Load the models
    print("📦 Caricamento modelli...")

    # Check both checkpoints BEFORE building any model: otherwise a missing file
    # would leave freshly-initialised (untrained) netG_s1/netG_s2 in the
    # notebook namespace, where later cells would mistake them for loaded models.
    for stage, ckpt in (("Stage-I", checkpoint_s1), ("Stage-II", checkpoint_s2)):
        if not os.path.exists(ckpt):
            print(f"⚠️ Checkpoint {stage} non trovato: {ckpt}")
            print(f"   Esegui prima il training {stage}!")
            raise FileNotFoundError(f"Checkpoint {stage} mancante")

    text_encoder = TextEncoder(model_name=config.ENCODER_MODEL_NAME, fine_tune=False).to(device)
    netG_s1 = GeneratorS1(config=config).to(device)
    netG_s2 = GeneratorS2(config=config).to(device)

    # NOTE(review): torch.load unpickles the file; only load checkpoints that
    # were produced by this notebook's own training cells.
    netG_s1.load_state_dict(torch.load(checkpoint_s1, map_location=device))
    print(f"✅ Checkpoint Stage-I caricato: {checkpoint_s1}")
    netG_s2.load_state_dict(torch.load(checkpoint_s2, map_location=device))
    print(f"✅ Checkpoint Stage-II caricato: {checkpoint_s2}")

    # Inference only: switch every module to eval mode
    netG_s1.eval()
    netG_s2.eval()
    text_encoder.eval()

    print("✅ Modelli caricati!")

    # Process the text
    from transformers import AutoTokenizer
    tokenizer = AutoTokenizer.from_pretrained(config.ENCODER_MODEL_NAME)

    # NOTE(review): this cell reads the tokenizer length from
    # config.MAX_TEXT_LENGTH, while the multi-prompt cell reads
    # config.MAX_SEQ_LEN. src.config is not visible from here, so accept
    # either name; confirm which one the project actually defines.
    max_len = getattr(config, "MAX_TEXT_LENGTH", getattr(config, "MAX_SEQ_LEN", None))
    if max_len is None:
        raise AttributeError("src.config non definisce né MAX_TEXT_LENGTH né MAX_SEQ_LEN")

    # Tokenize
    tokens = tokenizer(PROMPT, padding='max_length', max_length=max_len,
                       truncation=True, return_tensors='pt')

    input_ids = tokens['input_ids'].to(device)
    attention_mask = tokens['attention_mask'].to(device)

    print("🧠 Elaborazione testo...")

    with torch.no_grad():
        # Encode the text
        cls_embedding, hidden_states = text_encoder(input_ids, attention_mask)

        # Generate Stage-I (64x64)
        noise = torch.randn(1, config.Z_DIM, device=device)
        stage1_img, stage1_mu = netG_s1(cls_embedding, hidden_states, noise)

        # Generate Stage-II (215x215)
        stage2_img, _ = netG_s2(stage1_img, cls_embedding, stage1_mu)

    print("🎨 Generazione completata!")

    # Show the results
    plt.figure(figsize=(12, 4))

    # Stage-I
    plt.subplot(1, 2, 1)
    img_s1 = _to_displayable(stage1_img)
    plt.imshow(img_s1)
    plt.title('Stage-I (64x64)', fontsize=14, fontweight='bold')
    plt.axis('off')

    # Stage-II
    plt.subplot(1, 2, 2)
    img_s2 = _to_displayable(stage2_img)
    plt.imshow(img_s2)
    plt.title('Stage-II (215x215)', fontsize=14, fontweight='bold')
    plt.axis('off')

    plt.suptitle(f"Pokemon Generato: '{PROMPT}'", fontsize=16, fontweight='bold')
    plt.tight_layout()
    plt.show()

    # Save the image
    output_path = "results/my_generated_pokemon.png"
    os.makedirs(os.path.dirname(output_path), exist_ok=True)
    save_image(stage2_img, output_path, normalize=True)
    print(f"💾 Immagine salvata in: {output_path}")

    print("\n🎉 GENERAZIONE COMPLETATA!")
    print("💡 Cambia il PROMPT nella cella sopra per generare Pokemon diversi!")

except Exception as e:
    print(f"❌ Errore durante la generazione: {e}")
    import traceback
    traceback.print_exc()

In [None]:
# 🎯 MULTIPLE TESTS - generate several different Pokemon
#
# Runs the two-stage generator once per entry of test_prompts, shows the
# Stage-II outputs in a single row and saves each one under results/.
# This cell reuses the objects created by the generation cell (models,
# tokenizer, device, config, torch, save_image) and does not build its own.
# Errors are printed with a traceback but not re-raised.
import matplotlib.pyplot as plt
import numpy as np

print("🎯 TEST MULTIPLI - GENERAZIONE POKEMON DIVERSI")
print("=" * 50)

# Prompts to try
test_prompts = [
    "a blue water pokemon with fins",
    "a yellow electric mouse pokemon",
    "a green grass type pokemon with leaves",
    "a purple psychic pokemon with big eyes",
    "a brown ground type pokemon"
]

try:
    # Verify EVERY name this cell borrows from the generation cell, not only the
    # two generators: otherwise a half-finished run of that cell surfaces here
    # as an unexplained NameError.
    required = ('netG_s1', 'netG_s2', 'text_encoder', 'tokenizer',
                'device', 'config', 'torch', 'save_image')
    missing = [name for name in required if name not in globals()]
    if missing:
        print("⚠️ Modelli non caricati. Esegui prima la cella di generazione!")
        raise RuntimeError(f"Modelli non disponibili (mancano: {', '.join(missing)})")

    # NOTE(review): this cell reads the tokenizer length from config.MAX_SEQ_LEN,
    # while the generation cell reads config.MAX_TEXT_LENGTH. src.config is not
    # visible from here, so accept either name; confirm which one the project
    # actually defines.
    max_len = getattr(config, "MAX_SEQ_LEN", getattr(config, "MAX_TEXT_LENGTH", None))
    if max_len is None:
        raise AttributeError("src.config non definisce né MAX_SEQ_LEN né MAX_TEXT_LENGTH")

    plt.figure(figsize=(20, 8))

    for position, prompt in enumerate(test_prompts, start=1):
        print(f"🎨 Generando: '{prompt}'")

        # Tokenize
        tokens = tokenizer(prompt, padding='max_length', max_length=max_len,
                           truncation=True, return_tensors='pt')

        input_ids = tokens['input_ids'].to(device)
        attention_mask = tokens['attention_mask'].to(device)

        with torch.no_grad():
            # Encode the text
            cls_embedding, hidden_states = text_encoder(input_ids, attention_mask)

            # Fresh noise for every prompt
            noise = torch.randn(1, config.Z_DIM, device=device)
            stage1_img, stage1_mu = netG_s1(cls_embedding, hidden_states, noise)
            stage2_img, _ = netG_s2(stage1_img, cls_embedding, stage1_mu)

        # One column per prompt, so editing test_prompts never breaks the grid
        plt.subplot(1, len(test_prompts), position)
        img_display = (stage2_img.squeeze().permute(1, 2, 0).cpu() * 0.5 + 0.5).numpy()
        img_display = np.clip(img_display, 0, 1)
        plt.imshow(img_display)
        plt.title(f"{prompt[:20]}...", fontsize=10)
        plt.axis('off')

        # Save every image
        save_image(stage2_img, f"results/test_pokemon_{position}.png", normalize=True)

    plt.suptitle('Pokemon Generati da Descrizioni Diverse', fontsize=16, fontweight='bold')
    plt.tight_layout()
    plt.show()

    print("\n✅ TUTTI I TEST COMPLETATI!")
    print("📁 Immagini salvate in: results/test_pokemon_*.png")

except Exception as e:
    print(f"❌ Errore durante i test multipli: {e}")
    import traceback
    traceback.print_exc()

# 🎉 CONGRATULAZIONI!

**Hai completato con successo il training di StackGAN!** 🏆

## 📊 Cosa hai ottenuto:
- ✅ **Stage-I trained**: Genera Pokemon 64x64 da testo
- ✅ **Stage-II trained**: Raffina le immagini a 215x215
- ✅ **Pipeline completa**: Text → 64x64 → 215x215
- ✅ **Modelli salvati**: Pronti per il riuso

## 📁 File generati:
- `results/checkpoints/generator_s1.pth` - Modello Stage-I
- `results/checkpoints/stage2/generator_s2.pth` - Modello Stage-II
- `results/generated_images/` - Immagini di training
- `results/logs/` - Log delle loss
- `results/my_generated_pokemon.png` - La tua generazione!

## 🎮 Prossimi passi:
1. **Modifica i prompt** nella cella di generazione
2. **Aumenta le epoche** per migliorare la qualità
3. **Crea un'interfaccia Gradio** per demo web
4. **Sperimenta con nuovi dataset**

**Il tuo StackGAN è pronto! 🚀**