In [1]:
import os
from pathlib import Path
from tqdm import tqdm

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torchvision import transforms, datasets
from torchvision.utils import save_image, make_grid

In [None]:
# Параметры
IMAGE_SIZE = 64
BATCH = 128
LATENT_DIM = 256
NUM_EPOCHS = 80
LR = 2e-4
EMBED_DIM = 128   
BASE_CHANNELS = 64
BETA = 1.0     
SAMPLE_PER_CLASS = 6

In [None]:
DATA_ROOT = Path("data/animals")
OUT_DIR = Path("outputs_cvae")
OUT_DIR.mkdir(parents=True, exist_ok=True)
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Device:", DEVICE)

Device: cuda


In [None]:
# DataLoader
tf = transforms.Compose([
    transforms.Resize((IMAGE_SIZE, IMAGE_SIZE)),
    transforms.CenterCrop(IMAGE_SIZE),
    transforms.ToTensor(),
    transforms.Normalize([0.5,0.5,0.5],[0.5,0.5,0.5])  
])

dataset = datasets.ImageFolder(root=str(DATA_ROOT), transform=tf)
if len(dataset) == 0:
    raise RuntimeError(f"No images found in {DATA_ROOT}")

class_names = dataset.classes
NUM_CLASSES = len(class_names)
print(f"Found {len(dataset)} images, {NUM_CLASSES} classes:", class_names)

loader = DataLoader(dataset, batch_size=BATCH, shuffle=True, num_workers=4, pin_memory=(DEVICE.type=="cuda"), drop_last=True)

Found 13474 images, 5 classes: ['cat', 'dog', 'elephant', 'horse', 'lion']


In [None]:
# Утилиты
def weights_init(m):
    if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d)):
        if m.weight is not None:
            nn.init.normal_(m.weight.data, 0.0, 0.02)
        if m.bias is not None:
            nn.init.constant_(m.bias.data, 0)
    elif isinstance(m, nn.Linear):
        if m.weight is not None:
            nn.init.normal_(m.weight.data, 0.0, 0.02)
        if m.bias is not None:
            nn.init.constant_(m.bias.data, 0)
    elif isinstance(m, nn.BatchNorm2d):
        if m.weight is not None:
            nn.init.normal_(m.weight.data, 1.0, 0.02)
        if m.bias is not None:
            nn.init.constant_(m.bias.data, 0)

def save_image_grid(tensor, path, nrow=8):
    tensor = tensor.clamp(-1, 1)
    save_image((tensor + 1) / 2, path, nrow=nrow)

In [None]:
# CVAE model (Encoder + Decoder + reparam)
class Encoder(nn.Module):
    def __init__(self, img_channels=3, base=BASE_CHANNELS, embed_dim=EMBED_DIM, n_classes=NUM_CLASSES, latent_dim=LATENT_DIM):
        super().__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(img_channels, base, 4, 2, 1), nn.BatchNorm2d(base), nn.ReLU(True),
            nn.Conv2d(base, base*2, 4, 2, 1), nn.BatchNorm2d(base*2), nn.ReLU(True),
            nn.Conv2d(base*2, base*4, 4, 2, 1), nn.BatchNorm2d(base*4), nn.ReLU(True),
            nn.Conv2d(base*4, base*8, 4, 2, 1), nn.BatchNorm2d(base*8), nn.ReLU(True)
        )
        feat_dim = base*8*4*4
        self.label_emb = nn.Embedding(n_classes, embed_dim)

        self.fc_mu = nn.Linear(feat_dim + embed_dim, latent_dim)
        self.fc_logvar = nn.Linear(feat_dim + embed_dim, latent_dim)

    def forward(self, x, labels):
        b = x.size(0)
        f = self.conv(x)         
        f = f.view(b, -1)        
        le = self.label_emb(labels)  
        h = torch.cat([f, le], dim=1)
        mu = self.fc_mu(h)
        logvar = self.fc_logvar(h)
        return mu, logvar

class Decoder(nn.Module):
    def __init__(self, img_channels=3, base=BASE_CHANNELS, embed_dim=EMBED_DIM, n_classes=NUM_CLASSES, latent_dim=LATENT_DIM):
        super().__init__()
        self.label_emb = nn.Embedding(n_classes, embed_dim)
        self.latent_input_dim = latent_dim + embed_dim
        self.fc = nn.Linear(self.latent_input_dim, base*8*4*4)

        self.deconv = nn.Sequential(
            nn.ConvTranspose2d(base*8, base*4, 4, 2, 1), nn.BatchNorm2d(base*4), nn.ReLU(True),
            nn.ConvTranspose2d(base*4, base*2, 4, 2, 1), nn.BatchNorm2d(base*2), nn.ReLU(True),
            nn.ConvTranspose2d(base*2, base, 4, 2, 1), nn.BatchNorm2d(base), nn.ReLU(True),
            nn.ConvTranspose2d(base, img_channels, 4, 2, 1),
            nn.Tanh()   # outputs in [-1,1]
        )

    def forward(self, z, labels):
        le = self.label_emb(labels)
        z_cond = torch.cat([z, le], dim=1)  
        x = self.fc(z_cond)
        x = x.view(-1, BASE_CHANNELS*8, 4, 4)
        x = self.deconv(x)
        return x

class ConditionalVAE(nn.Module):
    def __init__(self, n_classes=NUM_CLASSES, latent_dim=LATENT_DIM, embed_dim=EMBED_DIM):
        super().__init__()
        self.encoder = Encoder(n_classes=n_classes, latent_dim=latent_dim, embed_dim=embed_dim)
        self.decoder = Decoder(n_classes=n_classes, latent_dim=latent_dim, embed_dim=embed_dim)

    def reparameterize(self, mu, logvar):
        std = (0.5 * logvar).exp()   # exp(0.5*logvar)
        eps = torch.randn_like(std)
        return mu + eps * std

    def forward(self, x, labels):
        mu, logvar = self.encoder(x, labels)
        z = self.reparameterize(mu, logvar)
        recon = self.decoder(z, labels)
        return recon, mu, logvar

In [None]:
# Инициализация
model = ConditionalVAE(n_classes=NUM_CLASSES, latent_dim=LATENT_DIM, embed_dim=EMBED_DIM).to(DEVICE)
model.apply(weights_init)

optimizer = torch.optim.Adam(model.parameters(), lr=LR, betas=(0.5, 0.999))

# fixed noise/labels для визуализации 
fixed_z = torch.randn(NUM_CLASSES * SAMPLE_PER_CLASS, LATENT_DIM, device=DEVICE)
fixed_labels = torch.tensor([i for i in range(NUM_CLASSES) for _ in range(SAMPLE_PER_CLASS)], device=DEVICE)

In [None]:
# loss функции и train step
def loss_function(recon_x, x, mu, logvar):
    recon_loss = F.mse_loss(recon_x, x, reduction='mean')
    # KL divergence
    kld = -0.5 * torch.mean(1 + logvar - mu.pow(2) - logvar.exp())
    return recon_loss, kld

def train_epoch(model, loader, optimizer, device, epoch, beta=BETA):
    model.train()
    running_recon = 0.0
    running_kld = 0.0
    pbar = tqdm(loader, desc=f"CVAE epoch {epoch+1}")
    for imgs, labels in pbar:
        imgs = imgs.to(device)
        labels = labels.to(device)
        optimizer.zero_grad()
        recon, mu, logvar = model(imgs, labels)
        recon_loss, kld = loss_function(recon, imgs, mu, logvar)
        loss = recon_loss + beta * kld
        loss.backward()
        optimizer.step()

        running_recon += float(recon_loss.item())
        running_kld += float(kld.item())
        pbar.set_postfix({'recon': running_recon/ (pbar.n+1), 'kld': running_kld/(pbar.n+1)})
    return running_recon/len(loader), running_kld/len(loader)

In [None]:
# функции сохранения примеров / генерации / реконструкции
def save_reconstructions(model, data_loader, path, n=8):
    model.eval()
    imgs, labels = next(iter(data_loader))
    imgs = imgs[:n].to(DEVICE)
    labels = labels[:n].to(DEVICE)
    with torch.no_grad():
        recon, _, _ = model(imgs, labels)
    cat = torch.cat([imgs.cpu(), recon.cpu()], dim=0)
    save_image_grid(cat, path, nrow=n)

def generate_by_class(model, class_idx, n=8, save_path=None):
    model.eval()
    z = torch.randn(n, LATENT_DIM, device=DEVICE)
    labels = torch.tensor([class_idx]*n, device=DEVICE)
    with torch.no_grad():
        imgs = model.decoder(z, labels).cpu()
    if save_path:
        save_image_grid(imgs, save_path, nrow=n)
    return imgs

def sample_grid_fixed(model, fixed_z, fixed_labels, save_path):
    model.eval()
    with torch.no_grad():
        z = fixed_z.to(DEVICE)
        labels = fixed_labels.to(DEVICE)
        imgs = model.decoder(z, labels).cpu()
    save_image_grid(imgs, save_path, nrow=SAMPLE_PER_CLASS)

In [10]:
# training loop, чекпоинты, примеры
best_loss = 1e9
for epoch in range(NUM_EPOCHS):
    recon_avg, kld_avg = train_epoch(model, loader, optimizer, DEVICE, epoch, beta=BETA)
    recon_path = OUT_DIR / f"recon_epoch{epoch+1}.png"
    save_reconstructions(model, loader, recon_path, n=8)

    samples_path = OUT_DIR / f"samples_epoch{epoch+1}.png"
    sample_grid_fixed(model, fixed_z, fixed_labels, samples_path)

    # чекпоинты
    torch.save({
        'epoch': epoch+1,
        'model_state': model.state_dict(),
        'opt_state': optimizer.state_dict()
    }, OUT_DIR / f"cvae_checkpoint_epoch{epoch+1}.pth")

    print(f"Epoch {epoch+1}: recon={recon_avg:.5f}, kld={kld_avg:.5f}, saved: {recon_path.name}, {samples_path.name}")

# финальное сохранение
torch.save(model.state_dict(), OUT_DIR / "cvae_final.pth")
print("Training finished. Models and examples saved to", OUT_DIR)

CVAE epoch 1: 100%|██████████| 105/105 [00:21<00:00,  4.84it/s, recon=0.196, kld=0.151]


Epoch 1: recon=0.19074, kld=0.14655, saved: recon_epoch1.png, samples_epoch1.png


CVAE epoch 2: 100%|██████████| 105/105 [00:42<00:00,  2.49it/s, recon=0.159, kld=0.0445]


Epoch 2: recon=0.15771, kld=0.04403, saved: recon_epoch2.png, samples_epoch2.png


CVAE epoch 3: 100%|██████████| 105/105 [00:37<00:00,  2.79it/s, recon=0.152, kld=0.0396]


Epoch 3: recon=0.15214, kld=0.03964, saved: recon_epoch3.png, samples_epoch3.png


CVAE epoch 4: 100%|██████████| 105/105 [00:43<00:00,  2.40it/s, recon=0.147, kld=0.0374]


Epoch 4: recon=0.14722, kld=0.03740, saved: recon_epoch4.png, samples_epoch4.png


CVAE epoch 5: 100%|██████████| 105/105 [00:24<00:00,  4.36it/s, recon=0.144, kld=0.0358]


Epoch 5: recon=0.14114, kld=0.03515, saved: recon_epoch5.png, samples_epoch5.png


CVAE epoch 6: 100%|██████████| 105/105 [00:38<00:00,  2.75it/s, recon=0.138, kld=0.0335]


Epoch 6: recon=0.13668, kld=0.03321, saved: recon_epoch6.png, samples_epoch6.png


CVAE epoch 7: 100%|██████████| 105/105 [00:21<00:00,  4.92it/s, recon=0.133, kld=0.0322]


Epoch 7: recon=0.13345, kld=0.03218, saved: recon_epoch7.png, samples_epoch7.png


CVAE epoch 8: 100%|██████████| 105/105 [00:22<00:00,  4.75it/s, recon=0.132, kld=0.032] 


Epoch 8: recon=0.13094, kld=0.03165, saved: recon_epoch8.png, samples_epoch8.png


CVAE epoch 9: 100%|██████████| 105/105 [00:21<00:00,  4.87it/s, recon=0.13, kld=0.0318] 


Epoch 9: recon=0.12901, kld=0.03150, saved: recon_epoch9.png, samples_epoch9.png


CVAE epoch 10: 100%|██████████| 105/105 [00:34<00:00,  3.01it/s, recon=0.132, kld=0.0325]


Epoch 10: recon=0.12800, kld=0.03158, saved: recon_epoch10.png, samples_epoch10.png


CVAE epoch 11: 100%|██████████| 105/105 [00:27<00:00,  3.86it/s, recon=0.127, kld=0.0314]


Epoch 11: recon=0.12697, kld=0.03142, saved: recon_epoch11.png, samples_epoch11.png


CVAE epoch 12: 100%|██████████| 105/105 [00:32<00:00,  3.21it/s, recon=0.126, kld=0.0313]


Epoch 12: recon=0.12615, kld=0.03134, saved: recon_epoch12.png, samples_epoch12.png


CVAE epoch 13: 100%|██████████| 105/105 [00:45<00:00,  2.31it/s, recon=0.127, kld=0.0314]


Epoch 13: recon=0.12578, kld=0.03107, saved: recon_epoch13.png, samples_epoch13.png


CVAE epoch 14: 100%|██████████| 105/105 [00:42<00:00,  2.47it/s, recon=0.127, kld=0.0315]


Epoch 14: recon=0.12548, kld=0.03116, saved: recon_epoch14.png, samples_epoch14.png


CVAE epoch 15: 100%|██████████| 105/105 [00:25<00:00,  4.06it/s, recon=0.126, kld=0.0318]


Epoch 15: recon=0.12498, kld=0.03149, saved: recon_epoch15.png, samples_epoch15.png


CVAE epoch 16: 100%|██████████| 105/105 [00:25<00:00,  4.08it/s, recon=0.128, kld=0.0315]


Epoch 16: recon=0.12476, kld=0.03064, saved: recon_epoch16.png, samples_epoch16.png


CVAE epoch 17: 100%|██████████| 105/105 [00:22<00:00,  4.62it/s, recon=0.125, kld=0.0311]


Epoch 17: recon=0.12457, kld=0.03108, saved: recon_epoch17.png, samples_epoch17.png


CVAE epoch 18: 100%|██████████| 105/105 [00:22<00:00,  4.58it/s, recon=0.125, kld=0.031] 


Epoch 18: recon=0.12430, kld=0.03071, saved: recon_epoch18.png, samples_epoch18.png


CVAE epoch 19: 100%|██████████| 105/105 [00:24<00:00,  4.36it/s, recon=0.126, kld=0.0312]


Epoch 19: recon=0.12369, kld=0.03058, saved: recon_epoch19.png, samples_epoch19.png


CVAE epoch 20: 100%|██████████| 105/105 [00:28<00:00,  3.66it/s, recon=0.125, kld=0.0311]


Epoch 20: recon=0.12349, kld=0.03078, saved: recon_epoch20.png, samples_epoch20.png


CVAE epoch 21: 100%|██████████| 105/105 [00:29<00:00,  3.53it/s, recon=0.123, kld=0.0303]


Epoch 21: recon=0.12322, kld=0.03030, saved: recon_epoch21.png, samples_epoch21.png


CVAE epoch 22: 100%|██████████| 105/105 [00:26<00:00,  3.99it/s, recon=0.125, kld=0.0312]


Epoch 22: recon=0.12300, kld=0.03065, saved: recon_epoch22.png, samples_epoch22.png


CVAE epoch 23: 100%|██████████| 105/105 [00:23<00:00,  4.40it/s, recon=0.125, kld=0.0309]


Epoch 23: recon=0.12283, kld=0.03032, saved: recon_epoch23.png, samples_epoch23.png


CVAE epoch 24: 100%|██████████| 105/105 [00:22<00:00,  4.77it/s, recon=0.123, kld=0.0309]


Epoch 24: recon=0.12220, kld=0.03056, saved: recon_epoch24.png, samples_epoch24.png


CVAE epoch 25: 100%|██████████| 105/105 [00:21<00:00,  4.89it/s, recon=0.122, kld=0.0309]


Epoch 25: recon=0.12171, kld=0.03095, saved: recon_epoch25.png, samples_epoch25.png


CVAE epoch 26: 100%|██████████| 105/105 [00:26<00:00,  3.91it/s, recon=0.125, kld=0.0327]


Epoch 26: recon=0.12106, kld=0.03175, saved: recon_epoch26.png, samples_epoch26.png


CVAE epoch 27: 100%|██████████| 105/105 [00:27<00:00,  3.82it/s, recon=0.123, kld=0.0318]


Epoch 27: recon=0.12040, kld=0.03124, saved: recon_epoch27.png, samples_epoch27.png


CVAE epoch 28: 100%|██████████| 105/105 [00:26<00:00,  4.00it/s, recon=0.122, kld=0.0329]


Epoch 28: recon=0.12045, kld=0.03259, saved: recon_epoch28.png, samples_epoch28.png


CVAE epoch 29: 100%|██████████| 105/105 [00:23<00:00,  4.42it/s, recon=0.121, kld=0.0314]


Epoch 29: recon=0.11987, kld=0.03114, saved: recon_epoch29.png, samples_epoch29.png


CVAE epoch 30: 100%|██████████| 105/105 [00:23<00:00,  4.39it/s, recon=0.123, kld=0.0322]


Epoch 30: recon=0.11944, kld=0.03124, saved: recon_epoch30.png, samples_epoch30.png


CVAE epoch 31: 100%|██████████| 105/105 [00:25<00:00,  4.16it/s, recon=0.123, kld=0.0317]


Epoch 31: recon=0.11939, kld=0.03080, saved: recon_epoch31.png, samples_epoch31.png


CVAE epoch 32: 100%|██████████| 105/105 [00:21<00:00,  4.93it/s, recon=0.122, kld=0.0319]


Epoch 32: recon=0.11937, kld=0.03129, saved: recon_epoch32.png, samples_epoch32.png


CVAE epoch 33: 100%|██████████| 105/105 [00:22<00:00,  4.77it/s, recon=0.123, kld=0.0431]


Epoch 33: recon=0.12154, kld=0.04269, saved: recon_epoch33.png, samples_epoch33.png


CVAE epoch 34: 100%|██████████| 105/105 [00:25<00:00,  4.07it/s, recon=0.119, kld=0.0306]


Epoch 34: recon=0.11908, kld=0.03058, saved: recon_epoch34.png, samples_epoch34.png


CVAE epoch 35: 100%|██████████| 105/105 [00:25<00:00,  4.15it/s, recon=0.121, kld=0.0313]


Epoch 35: recon=0.11887, kld=0.03074, saved: recon_epoch35.png, samples_epoch35.png


CVAE epoch 36: 100%|██████████| 105/105 [00:20<00:00,  5.11it/s, recon=0.12, kld=0.031]  


Epoch 36: recon=0.11877, kld=0.03070, saved: recon_epoch36.png, samples_epoch36.png


CVAE epoch 37: 100%|██████████| 105/105 [00:17<00:00,  6.00it/s, recon=0.12, kld=0.0311] 


Epoch 37: recon=0.11857, kld=0.03078, saved: recon_epoch37.png, samples_epoch37.png


CVAE epoch 38: 100%|██████████| 105/105 [00:20<00:00,  5.15it/s, recon=0.118, kld=0.0308]


Epoch 38: recon=0.11812, kld=0.03080, saved: recon_epoch38.png, samples_epoch38.png


CVAE epoch 39: 100%|██████████| 105/105 [00:17<00:00,  6.05it/s, recon=0.119, kld=0.0309]


Epoch 39: recon=0.11862, kld=0.03087, saved: recon_epoch39.png, samples_epoch39.png


CVAE epoch 40: 100%|██████████| 105/105 [00:18<00:00,  5.59it/s, recon=0.119, kld=0.031] 


Epoch 40: recon=0.11795, kld=0.03072, saved: recon_epoch40.png, samples_epoch40.png


CVAE epoch 41: 100%|██████████| 105/105 [00:21<00:00,  4.97it/s, recon=0.12, kld=0.0316] 


Epoch 41: recon=0.11809, kld=0.03098, saved: recon_epoch41.png, samples_epoch41.png


CVAE epoch 42: 100%|██████████| 105/105 [00:22<00:00,  4.67it/s, recon=0.121, kld=0.0319]


Epoch 42: recon=0.11770, kld=0.03099, saved: recon_epoch42.png, samples_epoch42.png


CVAE epoch 43: 100%|██████████| 105/105 [00:21<00:00,  4.93it/s, recon=0.118, kld=0.031] 


Epoch 43: recon=0.11778, kld=0.03098, saved: recon_epoch43.png, samples_epoch43.png


CVAE epoch 44: 100%|██████████| 105/105 [00:21<00:00,  4.78it/s, recon=0.118, kld=0.0314]


Epoch 44: recon=0.11718, kld=0.03110, saved: recon_epoch44.png, samples_epoch44.png


CVAE epoch 45: 100%|██████████| 105/105 [00:48<00:00,  2.18it/s, recon=0.12, kld=0.0317] 


Epoch 45: recon=0.11733, kld=0.03107, saved: recon_epoch45.png, samples_epoch45.png


CVAE epoch 46: 100%|██████████| 105/105 [00:29<00:00,  3.55it/s, recon=0.121, kld=0.0322]


Epoch 46: recon=0.11737, kld=0.03128, saved: recon_epoch46.png, samples_epoch46.png


CVAE epoch 47: 100%|██████████| 105/105 [00:39<00:00,  2.68it/s, recon=0.119, kld=0.0317]


Epoch 47: recon=0.11677, kld=0.03110, saved: recon_epoch47.png, samples_epoch47.png


CVAE epoch 48: 100%|██████████| 105/105 [00:48<00:00,  2.17it/s, recon=0.119, kld=0.032] 


Epoch 48: recon=0.11682, kld=0.03139, saved: recon_epoch48.png, samples_epoch48.png


CVAE epoch 49: 100%|██████████| 105/105 [00:37<00:00,  2.83it/s, recon=0.117, kld=0.0315]


Epoch 49: recon=0.11636, kld=0.03124, saved: recon_epoch49.png, samples_epoch49.png


CVAE epoch 50: 100%|██████████| 105/105 [00:32<00:00,  3.20it/s, recon=0.118, kld=0.0318]


Epoch 50: recon=0.11623, kld=0.03120, saved: recon_epoch50.png, samples_epoch50.png


CVAE epoch 51: 100%|██████████| 105/105 [00:50<00:00,  2.10it/s, recon=0.118, kld=0.0319]


Epoch 51: recon=0.11606, kld=0.03130, saved: recon_epoch51.png, samples_epoch51.png


CVAE epoch 52: 100%|██████████| 105/105 [00:49<00:00,  2.11it/s, recon=0.116, kld=0.0314]


Epoch 52: recon=0.11592, kld=0.03138, saved: recon_epoch52.png, samples_epoch52.png


CVAE epoch 53: 100%|██████████| 105/105 [00:40<00:00,  2.61it/s, recon=0.12, kld=0.0321] 


Epoch 53: recon=0.11624, kld=0.03120, saved: recon_epoch53.png, samples_epoch53.png


CVAE epoch 54: 100%|██████████| 105/105 [00:46<00:00,  2.26it/s, recon=0.12, kld=0.0649] 


Epoch 54: recon=0.11950, kld=0.06488, saved: recon_epoch54.png, samples_epoch54.png


CVAE epoch 55: 100%|██████████| 105/105 [00:41<00:00,  2.54it/s, recon=0.118, kld=0.0311]


Epoch 55: recon=0.11786, kld=0.03109, saved: recon_epoch55.png, samples_epoch55.png


CVAE epoch 56: 100%|██████████| 105/105 [00:44<00:00,  2.35it/s, recon=0.117, kld=0.0314]


Epoch 56: recon=0.11632, kld=0.03110, saved: recon_epoch56.png, samples_epoch56.png


CVAE epoch 57: 100%|██████████| 105/105 [00:20<00:00,  5.15it/s, recon=0.118, kld=0.0317]


Epoch 57: recon=0.11604, kld=0.03107, saved: recon_epoch57.png, samples_epoch57.png


CVAE epoch 58: 100%|██████████| 105/105 [00:19<00:00,  5.25it/s, recon=0.117, kld=0.0313]


Epoch 58: recon=0.11603, kld=0.03101, saved: recon_epoch58.png, samples_epoch58.png


CVAE epoch 59: 100%|██████████| 105/105 [00:20<00:00,  5.08it/s, recon=0.118, kld=0.0316]


Epoch 59: recon=0.11582, kld=0.03102, saved: recon_epoch59.png, samples_epoch59.png


CVAE epoch 60: 100%|██████████| 105/105 [00:20<00:00,  5.20it/s, recon=0.119, kld=0.0319]


Epoch 60: recon=0.11568, kld=0.03100, saved: recon_epoch60.png, samples_epoch60.png


CVAE epoch 61: 100%|██████████| 105/105 [00:32<00:00,  3.18it/s, recon=0.115, kld=0.0314]


Epoch 61: recon=0.11541, kld=0.03135, saved: recon_epoch61.png, samples_epoch61.png


CVAE epoch 62: 100%|██████████| 105/105 [00:19<00:00,  5.29it/s, recon=0.117, kld=0.0314]


Epoch 62: recon=0.11540, kld=0.03106, saved: recon_epoch62.png, samples_epoch62.png


CVAE epoch 63: 100%|██████████| 105/105 [00:20<00:00,  5.17it/s, recon=0.117, kld=0.0314]


Epoch 63: recon=0.11540, kld=0.03110, saved: recon_epoch63.png, samples_epoch63.png


CVAE epoch 64: 100%|██████████| 105/105 [00:19<00:00,  5.26it/s, recon=0.116, kld=0.0313]


Epoch 64: recon=0.11524, kld=0.03101, saved: recon_epoch64.png, samples_epoch64.png


CVAE epoch 65: 100%|██████████| 105/105 [00:20<00:00,  5.08it/s, recon=0.119, kld=0.0321]


Epoch 65: recon=0.11527, kld=0.03120, saved: recon_epoch65.png, samples_epoch65.png


CVAE epoch 66: 100%|██████████| 105/105 [00:21<00:00,  4.78it/s, recon=0.115, kld=0.0312]


Epoch 66: recon=0.11497, kld=0.03119, saved: recon_epoch66.png, samples_epoch66.png


CVAE epoch 67: 100%|██████████| 105/105 [00:29<00:00,  3.54it/s, recon=0.115, kld=0.0312]


Epoch 67: recon=0.11491, kld=0.03122, saved: recon_epoch67.png, samples_epoch67.png


CVAE epoch 68: 100%|██████████| 105/105 [00:41<00:00,  2.53it/s, recon=0.115, kld=0.0311]


Epoch 68: recon=0.11497, kld=0.03112, saved: recon_epoch68.png, samples_epoch68.png


CVAE epoch 69: 100%|██████████| 105/105 [00:19<00:00,  5.31it/s, recon=0.115, kld=0.0313]


Epoch 69: recon=0.11491, kld=0.03130, saved: recon_epoch69.png, samples_epoch69.png


CVAE epoch 70: 100%|██████████| 105/105 [00:20<00:00,  5.21it/s, recon=0.118, kld=0.032] 


Epoch 70: recon=0.11462, kld=0.03110, saved: recon_epoch70.png, samples_epoch70.png


CVAE epoch 71: 100%|██████████| 105/105 [00:19<00:00,  5.29it/s, recon=0.117, kld=0.0319]


Epoch 71: recon=0.11453, kld=0.03129, saved: recon_epoch71.png, samples_epoch71.png


CVAE epoch 72: 100%|██████████| 105/105 [00:20<00:00,  5.21it/s, recon=0.117, kld=0.0319]


Epoch 72: recon=0.11444, kld=0.03134, saved: recon_epoch72.png, samples_epoch72.png


CVAE epoch 73: 100%|██████████| 105/105 [00:19<00:00,  5.34it/s, recon=0.116, kld=0.0315]


Epoch 73: recon=0.11442, kld=0.03119, saved: recon_epoch73.png, samples_epoch73.png


CVAE epoch 74: 100%|██████████| 105/105 [00:20<00:00,  5.19it/s, recon=0.115, kld=0.0312]


Epoch 74: recon=0.11455, kld=0.03122, saved: recon_epoch74.png, samples_epoch74.png


CVAE epoch 75: 100%|██████████| 105/105 [00:19<00:00,  5.31it/s, recon=0.114, kld=0.0313]


Epoch 75: recon=0.11430, kld=0.03131, saved: recon_epoch75.png, samples_epoch75.png


CVAE epoch 76: 100%|██████████| 105/105 [00:21<00:00,  4.91it/s, recon=0.115, kld=0.0317]


Epoch 76: recon=0.11412, kld=0.03139, saved: recon_epoch76.png, samples_epoch76.png


CVAE epoch 77: 100%|██████████| 105/105 [00:21<00:00,  4.83it/s, recon=0.116, kld=0.0321]


Epoch 77: recon=0.11403, kld=0.03145, saved: recon_epoch77.png, samples_epoch77.png


CVAE epoch 78: 100%|██████████| 105/105 [00:21<00:00,  4.91it/s, recon=0.115, kld=0.0318]


Epoch 78: recon=0.11369, kld=0.03151, saved: recon_epoch78.png, samples_epoch78.png


CVAE epoch 79: 100%|██████████| 105/105 [00:21<00:00,  4.97it/s, recon=0.115, kld=0.0317]


Epoch 79: recon=0.11398, kld=0.03138, saved: recon_epoch79.png, samples_epoch79.png


CVAE epoch 80: 100%|██████████| 105/105 [00:21<00:00,  4.90it/s, recon=0.117, kld=0.0325]


Epoch 80: recon=0.11358, kld=0.03153, saved: recon_epoch80.png, samples_epoch80.png
Training finished. Models and examples saved to outputs_cvae


In [11]:
# пример использования генерации и визуализации нескольких классов

grid_imgs = []
grid_labels = []
n_per = SAMPLE_PER_CLASS
for cls_idx in range(min(5, NUM_CLASSES)):
    imgs = generate_by_class(model, cls_idx, n=n_per)
    grid_imgs.append(imgs)
    grid_labels += [class_names[cls_idx]] * n_per

# Склеим все в один тензор и сохраним
grid = torch.cat(grid_imgs, dim=0)
save_image_grid(grid, OUT_DIR / "generation_first5_classes.png", nrow=n_per)
print("Saved generation grid for first 5 classes:", OUT_DIR / "generation_first5_classes.png")

Saved generation grid for first 5 classes: outputs_cvae\generation_first5_classes.png


In [12]:
def class_name_to_index(name):
    lowered = [c.lower() for c in class_names]
    name_l = name.lower()
    if name_l in lowered:
        return lowered.index(name_l)
    for i,c in enumerate(lowered):
        if c.startswith(name_l) or name_l in c:
            return i
    raise ValueError("class not found")

def generate_by_name(name, n=6, save_path=None):
    idx = class_name_to_index(name)
    imgs = generate_by_class(model, idx, n=n, save_path=save_path)
    return imgs

In [14]:
generate_by_name("dog", n = 10, save_path=OUT_DIR / "generation_dog.png")


tensor([[[[ 0.0124, -0.0080, -0.0059,  ..., -0.0971, -0.1025, -0.0863],
          [ 0.0140,  0.0281,  0.0162,  ..., -0.1067, -0.1138, -0.0859],
          [ 0.0371,  0.0512,  0.0413,  ..., -0.1023, -0.1044, -0.0994],
          ...,
          [-0.7692, -0.8026, -0.7991,  ..., -0.2506, -0.2410, -0.2230],
          [-0.7388, -0.7971, -0.7912,  ..., -0.2435, -0.2424, -0.2243],
          [-0.6510, -0.7470, -0.7439,  ..., -0.2277, -0.2386, -0.1950]],

         [[ 0.0254,  0.0524,  0.0460,  ..., -0.0526, -0.0622, -0.0706],
          [ 0.0718,  0.0880,  0.0693,  ..., -0.0501, -0.0564, -0.0588],
          [ 0.0763,  0.1024,  0.1119,  ..., -0.0303, -0.0436, -0.0678],
          ...,
          [-0.7571, -0.8004, -0.7883,  ..., -0.2419, -0.2411, -0.2256],
          [-0.7566, -0.7982, -0.8025,  ..., -0.2423, -0.2364, -0.2307],
          [-0.6594, -0.7596, -0.7627,  ..., -0.2615, -0.2501, -0.2247]],

         [[-0.0440, -0.0457, -0.0523,  ..., -0.1836, -0.1927, -0.1924],
          [-0.0232, -0.0304, -