In [1]:
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset
import matplotlib.pyplot as plt
from tqdm import tqdm
from scipy.stats import gaussian_kde
import datetime  

# ---------- Residual Block ----------
class ResBlock(nn.Module):
    """Residual MLP block computing ``x + net(x)``.

    ``net`` is Linear(h, h) -> LeakyReLU -> Linear(h, h) -> LeakyReLU, so the
    feature width ``h`` is the same on the way in and on the way out.
    """

    def __init__(self, h):
        super().__init__()
        # Built as a list first; the resulting Sequential has indices 0..3,
        # with the two Linear layers at positions 0 and 2.
        layers = [
            nn.Linear(h, h),
            nn.LeakyReLU(),
            nn.Linear(h, h),
            nn.LeakyReLU(),
        ]
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        # Skip connection: add the transformed features back onto the input.
        residual = self.net(x)
        return x + residual

# ---------- Generator ----------
class Generator(nn.Module):
    """MLP generator mapping a latent vector of size ``z`` to a ``d``-dim output.

    Layout: Linear(z, h) + LeakyReLU -> ``n_res`` ResBlock(h) -> Linear(h, d).
    The final layer has no activation, so outputs are unbounded.

    Args:
        z: Latent dimensionality (also stored as ``self.z_dim``).
        h: Hidden width shared by every layer.
        d: Output dimensionality.
        n_res: Number of residual blocks in the body.
    """

    def __init__(self, z=2, h=24, d=2, n_res=4):
        super().__init__()
        # Kept on the instance so samplers can size their noise tensor.
        self.z_dim = z

        self.input = nn.Sequential(nn.Linear(z, h), nn.LeakyReLU())

        res_blocks = [ResBlock(h) for _ in range(n_res)]
        self.body = nn.Sequential(*res_blocks)

        self.output = nn.Linear(h, d)

    def forward(self, z):
        # input projection -> residual body -> linear head
        return self.output(self.body(self.input(z)))

# ---------- Discriminator ----------
class Discriminator(nn.Module):
    """MLP discriminator mapping a ``d``-dim sample to one sigmoid-squashed score.

    Layout: Linear(d, h) + LeakyReLU -> ``n_res`` ResBlock(h) -> Linear(h, 1) + Sigmoid,
    so every output lies in (0, 1).

    Args:
        d: Input dimensionality.
        h: Hidden width shared by every layer.
        n_res: Number of residual blocks in the body.
    """

    def __init__(self, d=2, h=24, n_res=3):
        super().__init__()
        self.input = nn.Sequential(nn.Linear(d, h), nn.LeakyReLU())

        res_blocks = [ResBlock(h) for _ in range(n_res)]
        self.body = nn.Sequential(*res_blocks)

        # Sigmoid is part of the head, so callers receive values in (0, 1), not logits.
        self.output = nn.Sequential(nn.Linear(h, 1), nn.Sigmoid())

    def forward(self, x):
        # input projection -> residual body -> sigmoid head
        return self.output(self.body(self.input(x)))

# ---------- Sampling function ----------
@torch.no_grad()
def sample_gan(G, n=2000, device="cpu"):
    """Draw ``n`` samples from generator ``G`` and return them as a NumPy array.

    Standard-normal noise of shape ``(n, G.z_dim)`` is created on ``device`` and
    passed through ``G`` with gradients disabled.

    Args:
        G: Module exposing a ``z_dim`` attribute and callable on the noise tensor.
        n: Number of samples to draw.
        device: Device on which the noise tensor is created. It is not checked
            against the device of ``G``'s parameters.

    Returns:
        ``G(z)`` moved to CPU and converted with ``.numpy()``.
    """
    # Remember the caller's mode: sampling is often done mid-training, and
    # leaving G in eval mode would silently alter later training steps for
    # generators that contain dropout or normalisation layers.
    was_training = G.training
    G.eval()
    try:
        z = torch.randn(n, G.z_dim, device=device)
        return G(z).cpu().numpy()
    finally:
        G.train(was_training)

# ---------- Model factory ----------
def models():
    """Construct the GAN pair used in this notebook and place both on the CPU.

    Returns:
        ``(G, D)``: ``Generator(z=2, h=32, d=2, n_res=2)`` and
        ``Discriminator(d=2, h=42, n_res=4)``.
    """
    # Generator is built first, then the discriminator (same order as before,
    # so parameter initialisation consumes the RNG identically).
    generator = Generator(z=2, h=32, d=2, n_res=2)
    discriminator = Discriminator(d=2, h=42, n_res=4)
    return generator.to("cpu"), discriminator.to("cpu")

In [2]:
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset
import matplotlib.pyplot as plt
from tqdm import tqdm
from scipy.stats import gaussian_kde
import datetime  

# ---------- Residual Block ----------
class VAEResBlock(nn.Module):
    """Residual block used by the VAE: ``x + Linear(ReLU(Linear(x)))``.

    Deliberately NOT named ``ResBlock``: the GAN cell defines a different
    ``ResBlock`` (LeakyReLU after both Linear layers) that ``Generator`` and
    ``Discriminator`` look up by name when they are instantiated. Re-using the
    name here would rebind it and silently change every GAN model built after
    this cell runs.
    """

    def __init__(self, h):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(h, h),
            nn.ReLU(),
            nn.Linear(h, h),
        )

    def forward(self, x):
        return x + self.net(x)


# ---------- VAE Model ----------
class VAE(nn.Module):
    """Residual-MLP variational autoencoder.

    Encoder: Linear(d, h) -> ReLU -> LayerNorm(h) -> ``n_res`` residual blocks,
    followed by two linear heads producing ``mu`` and ``logvar`` of size ``z``.
    Decoder: Linear(z, h) -> ReLU -> LayerNorm(h) -> ``n_res`` residual blocks
    -> Linear(h, d), with no output activation.

    Args:
        d: Data dimensionality.
        h: Hidden width.
        z: Latent dimensionality.
        n_res: Number of residual blocks in the encoder and in the decoder.
    """

    def __init__(self, d=2, h=256, z=2, n_res=3):
        super().__init__()
        # Encoder
        self.encoder = nn.Sequential(
            nn.Linear(d, h),
            nn.ReLU(),
            nn.LayerNorm(h),
            *[VAEResBlock(h) for _ in range(n_res)],
        )
        self.fc_mu = nn.Linear(h, z)
        self.fc_logvar = nn.Linear(h, z)

        # Decoder
        self.dec_in = nn.Sequential(
            nn.Linear(z, h),
            nn.ReLU(),
            nn.LayerNorm(h),
        )
        self.decoder = nn.Sequential(
            *[VAEResBlock(h) for _ in range(n_res)],
            nn.Linear(h, d)  # output is unbounded
        )

    def encode(self, x):
        """Return ``(mu, logvar)`` produced by the encoder heads for ``x``."""
        h = self.encoder(x)
        return self.fc_mu(h), self.fc_logvar(h)

    @staticmethod
    def reparameterize(mu, logvar):
        """Sample ``mu + exp(0.5 * logvar) * eps`` with ``eps ~ N(0, I)``."""
        std = torch.exp(0.5 * logvar)
        return mu + std * torch.randn_like(std)

    def decode(self, z):
        """Map latent codes ``z`` back to data space."""
        return self.decoder(self.dec_in(z))

    def forward(self, x):
        """Encode ``x``, sample a latent code, and decode it.

        Returns:
            ``(x_recon, mu, logvar)``.
        """
        mu, logvar = self.encode(x)
        z = self.reparameterize(mu, logvar)
        return self.decode(z), mu, logvar

    @torch.no_grad()
    def sample(self, n=2000, device="cpu"):
        """Decode ``n`` standard-normal latent draws into a NumPy array.

        The noise tensor is created on ``device``; the latent size is read from
        ``fc_mu.out_features``.
        """
        z_ = torch.randn(n, self.fc_mu.out_features, device=device)
        return self.decode(z_).cpu().numpy()

# ---------- Loss Function ----------
def vae_loss(x, x_recon, mu, logvar, beta=0.2):
    """Reconstruction MSE plus a ``beta``-weighted KL term.

    Args:
        x: Target tensor.
        x_recon: Reconstruction compared against ``x``.
        mu: Latent means.
        logvar: Latent log-variances.
        beta: Weight applied to the KL term.

    Returns:
        Scalar tensor ``recon + beta * kl``. Both terms are means over every
        element of their tensors (no per-sample sum over dimensions).
    """
    mse = nn.MSELoss(reduction="mean")
    recon_term = mse(x_recon, x)

    # Closed-form Gaussian KL integrand, averaged element-wise.
    kl_integrand = 1 + logvar - mu.pow(2) - logvar.exp()
    kl_term = -0.5 * kl_integrand.mean()

    return recon_term + beta * kl_term

# ---------- Model factory ----------
def model():
    """Construct the notebook's VAE (``d=2, h=128, z=24, n_res=3``) on the CPU."""
    return VAE(d=2, h=128, z=24, n_res=3).to("cpu")


In [3]:
!pip install torchviz



In [6]:
from torchviz import make_dot

# --- Create models ---
G = Generator(z=2, h=32, d=2, n_res=2)
D = Discriminator(d=2, h=42, n_res=4)

# --- Example input ---
z = torch.randn(1, 2)      # latent input
fake = G(z)                 # generator output

# --- Graph for Generator ---
gen_graph = make_dot(fake, params=dict(G.named_parameters()))
gen_graph.render("generator_graph", format="png")

# --- Graph for Discriminator ---
# `fake` carries the generator's autograd history. Detach it so the traced
# graph stops at D's input; otherwise every generator op is drawn as part of
# the "discriminator" graph (with G's parameters left unlabeled).
disc_graph = make_dot(D(fake.detach()), params=dict(D.named_parameters()))
disc_graph.render("discriminator_graph", format="png")

print("Saved generator_graph.png and discriminator_graph.png")

# --- Graph for VAE ---
vae = VAE(d=2, h=128, z=24, n_res=3)

# --- Example input ---
x = torch.randn(1, 2)

# Full forward pass
x_recon, mu, logvar = vae(x)

# For VAE, visualize the whole forward output (reconstruction)
vae_graph = make_dot(x_recon, params=dict(vae.named_parameters()))
vae_graph.render("vae_graph", format="png")

print("Saved vae_graph.png")


Saved generator_graph.png and discriminator_graph.png
Saved vae_graph.png


In [12]:
import io
import sys
import datetime
from torchsummary import summary
from PIL import Image, ImageDraw, ImageFont


def save_model_summary(model, input_size, filename=None):
    """
    Saves the torchsummary(model) output as a PNG image.
    Compatible with old versions of torchsummary.

    Args:
        model: Module passed straight to ``summary``.
        input_size: Per-sample input shape forwarded to ``summary`` (e.g. ``(2,)``).
        filename: Output path. A ``_YYYYmmdd_HHMMSS`` timestamp is inserted just
            before the extension (``.png`` is assumed when there is none).
            When ``None``, ``model_summary_<timestamp>.png`` is used.

    Side effects:
        Writes the image to disk and prints a confirmation line.
    """
    import os
    from contextlib import redirect_stdout

    # ---------- Timestamp ----------
    timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
    if filename is None:
        filename = f"model_summary_{timestamp}.png"
    else:
        # splitext touches only the real extension; str.replace(".png", ...)
        # added no timestamp for other extensions and rewrote every ".png"
        # occurrence in the path.
        root, ext = os.path.splitext(filename)
        filename = f"{root}_{timestamp}{ext or '.png'}"

    # ---------- Capture stdout ----------
    # summary() prints its table; redirect_stdout restores sys.stdout even if
    # summary() raises.
    buffer = io.StringIO()
    with redirect_stdout(buffer):
        summary(model, input_size=input_size)

    # Drop trailing newlines so no empty rows are rendered at the bottom.
    # "".split("\n") is [""], so `lines` is never empty and max() below is safe.
    lines = buffer.getvalue().rstrip("\n").split("\n")

    # ---------- Convert text → PNG ----------
    font = ImageFont.load_default()
    padding = 10
    max_width = max(font.getbbox(line)[2] for line in lines)
    line_height = font.getbbox("A")[3] + 4  # glyph height plus 4px line spacing
    img_height = line_height * len(lines)

    img = Image.new(
        "RGB",
        (max_width + 2 * padding, img_height + 2 * padding),
        color="white"
    )
    draw = ImageDraw.Draw(img)

    y = padding
    for line in lines:
        draw.text((padding, y), line, fill="black", font=font)
        y += line_height

    img.save(filename)
    print(f"Saved model summary → {filename}")


# Each model is built and summarised in turn (construct -> save, three times).
# All three take a 2-feature input, hence input_size=(2,) throughout.
vae = VAE(d=2, h=128, z=24, n_res=3)
save_model_summary(model=vae, input_size=(2,), filename="vae_summary.png")

G = Generator(z=2, h=32, d=2, n_res=2)
save_model_summary(model=G, input_size=(2,), filename="generator_summary.png")

D = Discriminator(d=2, h=42, n_res=4)
save_model_summary(model=D, input_size=(2,), filename="discriminator_summary.png")


Saved model summary → vae_summary_20251204_220244.png
Saved model summary → generator_summary_20251204_220244.png
Saved model summary → discriminator_summary_20251204_220244.png


In [13]:
import io
import sys
import datetime
from torchsummary import summary
from PIL import Image, ImageDraw, ImageFont


def get_summary_text(model, input_size):
    """Run ``summary(model, input_size=...)`` and return what it printed.

    ``sys.stdout`` is swapped for an in-memory buffer for the duration of the
    call and restored afterwards, including when ``summary`` raises.
    """
    captured = io.StringIO()
    original_stdout, sys.stdout = sys.stdout, captured
    try:
        summary(model, input_size=input_size)
    finally:
        sys.stdout = original_stdout

    return captured.getvalue()


def save_gan_summary(generator, discriminator, input_size_G, input_size_D, filename="gan_summary.png"):
    """
    Saves Generator + Discriminator torchsummary outputs in ONE combined PNG image.

    Args:
        generator: Model summarised under the "GENERATOR" heading.
        discriminator: Model summarised under the "DISCRIMINATOR" heading.
        input_size_G: Input size forwarded to the generator summary.
        input_size_D: Input size forwarded to the discriminator summary.
        filename: Output path. A ``_YYYYmmdd_HHMMSS`` timestamp is inserted just
            before the extension (``.png`` is assumed when there is none).

    Side effects:
        Writes the image to disk and prints a confirmation line.
    """
    import os

    # Timestamp for filename uniqueness.
    # splitext touches only the real extension; str.replace(".png", ...) added
    # no timestamp for other extensions and rewrote every ".png" in the path.
    timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
    root, ext = os.path.splitext(filename)
    filename = f"{root}_{timestamp}{ext or '.png'}"

    # ---------- Get summaries ----------
    gen_text = get_summary_text(generator, input_size_G)
    disc_text = get_summary_text(discriminator, input_size_D)

    full_text = (
        "===== GENERATOR =====\n" + gen_text +
        "\n\n===== DISCRIMINATOR =====\n" + disc_text
    )
    # Drop trailing newlines so no empty rows are rendered at the bottom;
    # the two heading lines guarantee `lines` is never empty.
    lines = full_text.rstrip("\n").split("\n")

    # ---------- Render summary as PNG ----------
    font = ImageFont.load_default()
    padding = 10
    max_width = max(font.getbbox(line)[2] for line in lines)
    line_height = font.getbbox("A")[3] + 4  # glyph height plus 4px line spacing
    img_height = line_height * len(lines)

    img = Image.new(
        "RGB",
        (max_width + 2 * padding, img_height + 2 * padding),
        color="white"
    )
    draw = ImageDraw.Draw(img)

    y = padding
    for line in lines:
        draw.text((padding, y), line, fill="black", font=font)
        y += line_height

    img.save(filename)
    print(f"Saved combined GAN summary → {filename}")


In [14]:
# Build a fresh generator/discriminator pair and write both summaries into a
# single image; both networks take a 2-feature input.
G = Generator(z=2, h=32, d=2, n_res=2)
D = Discriminator(d=2, h=42, n_res=4)

save_gan_summary(G, D, (2,), (2,), filename="gan_summary.png")


Saved combined GAN summary → gan_summary_20251204_220429.png
