In [7]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from sklearn.datasets import load_breast_cancer
import seaborn as sns
import io
import base64

import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
from torchvision import datasets
import torchvision.transforms as transforms
from torch.utils.data import random_split, DataLoader
import tqdm 

import time
import torch.optim as optim
import os
import torch_sgld
import json

In [3]:
# MNIST training split, downloaded into ./data when it is not already there.
# Images are only converted to tensors here (no normalisation is applied).
mnist_trainset = datasets.MNIST(
    root='./data',
    train=True,
    download=True,
    transform=transforms.ToTensor(),
)

In [11]:
# Fix the NumPy and PyTorch seeds for reproducibility.
np.random.seed(42)
torch.manual_seed(42)

# --- Hyperparameters -------------------------------------------------------
# Data / model sizes
input_dim = 28 * 28
hidden_dim = 256
batch_size = 128
# Energy-model optimisation
max_training_iterations = 1
ebm_lr = 1e-4
# SGLD sampling
sgld_lr = 1e-4
chain_length_per_epoch = 20
final_sampling_steps = 500
num_generated_samples = 128

# Use the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# ToTensor yields pixel values in [0, 1]; normalising with mean 0.5 / std 0.5
# rescales them to [-1, 1].
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.5,), std=(0.5,)),
])

# MNIST training data and a shuffled mini-batch loader over it.
mnist_trainset = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
mnist_train_loader = DataLoader(dataset=mnist_trainset, batch_size=batch_size, shuffle=True)

# Energy network: three stride-2 convolutions (28 -> 14 -> 7 -> 4 spatially for
# a 28x28 input) followed by a linear layer that outputs one scalar per image.
E = nn.Sequential(
    nn.Conv2d(in_channels=1, out_channels=64, kernel_size=5, stride=2, padding=2),
    nn.LeakyReLU(negative_slope=0.2),
    nn.Conv2d(in_channels=64, out_channels=128, kernel_size=5, stride=2, padding=2),
    nn.LeakyReLU(negative_slope=0.2),
    nn.Conv2d(in_channels=128, out_channels=256, kernel_size=5, stride=2, padding=2),
    nn.LeakyReLU(negative_slope=0.2),
    nn.Flatten(),
    nn.Linear(in_features=256 * 4 * 4, out_features=1),
)
E = E.to(device)

# Adam updates the energy network's weights.
optimizer = optim.Adam(params=E.parameters(), lr=ebm_lr)

# State "x" of the SGLD chain: a batch of Gaussian-noise images created as a
# leaf tensor that tracks gradients, which the SGLD sampler updates directly.
x_sgld = torch.randn(batch_size, 1, 28, 28, device=device, requires_grad=True)
sampler = torch_sgld.SGLD([x_sgld], lr=sgld_lr)

# Iterator over batches of real data.
data_iterator = iter(mnist_train_loader)

# Train the energy-based model with SGLD-generated negative samples, or reuse a
# previously saved checkpoint. Either way `E` holds trained weights afterwards.
ebm_checkpoint_path = "energy_model_mnist2.pth"

if os.path.exists(ebm_checkpoint_path):
    # Cached run: restore the saved weights. Without this, skipping training
    # would leave E at its random initialisation.
    # Note: x_sgld is not restored here and keeps its fresh random values.
    E.load_state_dict(torch.load(ebm_checkpoint_path, map_location=device))
else:
    for i in tqdm.tqdm(range(max_training_iterations), desc = "Treinando EBM"):
        # Fetch a batch of real images, restarting the loader once exhausted.
        try:
            real_images, _ = next(data_iterator)
        except StopIteration:
            data_iterator = iter(mnist_train_loader)
            real_images, _ = next(data_iterator)

        real_images = real_images.to(device)

        # Sampling phase: run the SGLD chain on x_sgld using the energy gradient
        # with respect to the images. E's weights are frozen meanwhile so that
        # backward() does not compute weight gradients that would be discarded.
        E.requires_grad_(False)
        try:
            for _ in range(chain_length_per_epoch):
                sampler.zero_grad()
                potential = E(x_sgld)
                potential.sum().backward()
                sampler.step()
                # Keep the samples inside the [-1, 1] range; the in-place edit
                # of a leaf tensor must happen outside autograd tracking.
                with torch.no_grad():
                    x_sgld.clamp_(-1.0, 1.0)
        finally:
            # Always re-enable weight gradients, even if sampling is interrupted.
            E.requires_grad_(True)

        # Optimisation phase: one Adam step on the energy network.
        optimizer.zero_grad()

        # Mean energy of the real samples.
        positive_energy = E(real_images).mean()
        # Mean energy of the generated samples (detached: no gradient to x_sgld).
        negative_energy = E(x_sgld.detach()).mean()

        # Loss: lower the energy of real data, raise it for generated samples.
        loss = positive_energy - negative_energy
        loss.backward()
        torch.nn.utils.clip_grad_norm_(E.parameters(), max_norm = 1.0)
        optimizer.step()

    # Persist the trained weights and the final state of the SGLD chain.
    torch.save(E.state_dict(), ebm_checkpoint_path)
    with open("generated_images.json", "w") as f:
        json.dump(x_sgld.detach().cpu().numpy().tolist(), f)

Treinando EBM: 100%|██████████| 1/1 [00:07<00:00,  7.74s/it]
