In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import models
from typing import Tuple

import importlib
import resnet_vae
importlib.reload(resnet_vae)

from resnet_vae import ResNetVAE, vae_loss_function

In [2]:
class ResNetEncoder(nn.Module):
    """
    Uses torchvision ResNet as encoder and returns a spatial feature map.
    If pretrained=True, downloads imagenet weights via torchvision.
    output shape (B, C_enc, H/32, W/32) for typical ResNets (resnet18/34/50).
    """
    def __init__(self, backbone: str = "resnet50", pretrained: bool = True, replace_stride_with_dilation=None):
        super().__init__()
        assert backbone in ("resnet18", "resnet34", "resnet50"), "supported: resnet18, resnet34, resnet50"
        if backbone == "resnet18":
            base = models.resnet18(pretrained=pretrained, replace_stride_with_dilation=replace_stride_with_dilation)
            out_channels = 512
        elif backbone == "resnet34":
            base = models.resnet34(pretrained=pretrained, replace_stride_with_dilation=replace_stride_with_dilation)
            out_channels = 512
        else:
            base = models.resnet50(pretrained=pretrained, replace_stride_with_dilation=replace_stride_with_dilation)
            out_channels = 2048

        # Keep the layers up to layer4 (exclude avgpool & fc)
        # We'll reuse conv1, bn1, relu, maxpool, layer1..layer4
        self.stem = nn.Sequential(
            base.conv1,
            base.bn1,
            base.relu,
            base.maxpool
        )
        self.layer1 = base.layer1
        self.layer2 = base.layer2
        self.layer3 = base.layer3
        self.layer4 = base.layer4

        self.out_channels = out_channels

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        x: (B, 3, H, W)
        returns: (B, out_channels, H/32, W/32) for default ResNet strides
        """
        x = self.stem(x)      # /4 then maxpool => /4
        x = self.layer1(x)    # /4
        x = self.layer2(x)    # /8
        x = self.layer3(x)    # /16
        x = self.layer4(x)    # /32
        return x


In [3]:
class ConvDecoder(nn.Module):
    def __init__(self, in_channels, out_channels=3):
        super().__init__()
        layers = []
        ch = in_channels
        for _ in range(5):  # upsample ×2 five times (32× → 1×)
            layers += [
                nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False),
                nn.Conv2d(ch, max(ch // 2, 64), 3, padding=1),
                nn.BatchNorm2d(max(ch // 2, 64)),
                nn.ReLU(inplace=True),
            ]
            ch = max(ch // 2, 64)
        layers += [nn.Conv2d(ch, out_channels, 3, padding=1), nn.Sigmoid()]
        self.decoder = nn.Sequential(*layers)

    def forward(self, z):
        return self.decoder(z)


In [4]:
class ResNetVAE(nn.Module):
    def __init__(self, backbone="resnet18", pretrained=True, latent_dim=256):
        super().__init__()
        self.encoder = ResNetEncoder(backbone=backbone, pretrained=pretrained)
        self.enc_out_ch = self.encoder.out_channels
        self.latent_dim = latent_dim

        # Project encoder output to latent mean & logvar
        self.fc_mu = nn.Linear(self.enc_out_ch * 8 * 8, latent_dim)
        self.fc_logvar = nn.Linear(self.enc_out_ch * 8 * 8, latent_dim)

        # Project latent back to feature map for decoder
        self.fc_decode = nn.Linear(latent_dim, self.enc_out_ch * 8 * 8)

        self.decoder = ConvDecoder(self.enc_out_ch, out_channels=3)

    def reparameterize(self, mu, logvar):
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return mu + eps * std

    def forward(self, x):
        b = x.size(0)
        h = self.encoder(x)  # (B, C, 8, 8) for 256×256 inputs
        h_flat = h.view(b, -1)

        mu = self.fc_mu(h_flat)
        logvar = self.fc_logvar(h_flat)
        z = self.reparameterize(mu, logvar)

        # decode
        h_dec = self.fc_decode(z).view(b, self.enc_out_ch, 8, 8)
        recon = self.decoder(h_dec)
        return recon, mu, logvar


In [5]:
def vae_loss_function(recon_x, x, mu, logvar):
    recon_loss = F.mse_loss(recon_x, x, reduction='mean')
    # KL divergence: 0.5 * sum(μ² + σ² - logσ² - 1)
    kld = -0.5 * torch.mean(1 + logvar - mu.pow(2) - logvar.exp())
    return recon_loss + 0.001 * kld, recon_loss, kld


In [6]:
if __name__ == "__main__":
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    vae = ResNetVAE(backbone="resnet18", pretrained=True, latent_dim=256).to(device)
    opt = torch.optim.Adam(vae.parameters(), lr=1e-4)

    dummy = torch.rand(4, 3, 256, 256).to(device)
    recon, mu, logvar = vae(dummy)
    loss, rec, kld = vae_loss_function(recon, dummy, mu, logvar)
    loss.backward()
    opt.step()
    print(f"Training step: total={loss.item():.4f}, rec={rec.item():.4f}, kl={kld.item():.4f}")

    # Generate new images
    vae.eval()
    with torch.no_grad():
        z = torch.randn(8, vae.latent_dim).to(device)
        dec = vae.fc_decode(z).view(-1, vae.enc_out_ch, 8, 8)
        samples = vae.decoder(dec)
    print("Generated images:", samples.shape)



Training step: total=0.0926, rec=0.0920, kl=0.6329
Generated images: torch.Size([8, 3, 256, 256])


In [7]:
from torchvision import datasets, transforms
from torch.utils.data import DataLoader

# Transformations
transform = transforms.Compose([
    transforms.Resize((256, 256)),
    transforms.ToTensor(),
])

# Load dataset
dataset = datasets.ImageFolder(root="/Users/poulam/Desktop/Generative Bullshit/archive2/Hands", transform=transform)

# DataLoader
train_loader = DataLoader(dataset, batch_size=8, shuffle=True, num_workers=4)

In [8]:
from resnet_vae import ResNetVAE

model = ResNetVAE(latent_dim=128)

In [22]:
import sys
import importlib
sys.path.append("/Users/poulam/Desktop/Generative Bullshit")
import resnet_vae
importlib.reload(resnet_vae)
from resnet_vae import ResNetVAE, vae_loss_function

In [23]:
import torch
from resnet_vae import ResNetVAE, vae_loss_function

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

vae = ResNetVAE(backbone="resnet18", pretrained=True, latent_dim=256, latent_spatial=8).to(device)
optimizer = torch.optim.Adam(vae.parameters(), lr=1e-4)

In [24]:
import torch
from torch.utils.data import DataLoader
from tqdm import tqdm
from resnet_vae import ResNetVAE, vae_loss_function

# assume you already have dataset and dataloader
for epoch in range(50):
    vae.train()
    pbar = tqdm(train_loader, desc=f"Epoch {epoch+1}/50")
    for imgs, _ in pbar:
        imgs = imgs.to(device)
        recon, mu, logvar = vae(imgs)
        loss, rec_loss, kl_loss = vae_loss_function(recon, imgs, mu, logvar)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        pbar.set_postfix({
            "Total": f"{loss.item():.2f}",
            "Recon": f"{rec_loss.item():.2f}",
            "KL": f"{kl_loss.item():.2f}"
        })

Epoch 1/50: 100%|█| 1385/1385 [22:50<00:00,  1.01it/s, Total=8753.78, Recon=7354
Epoch 2/50: 100%|█| 1385/1385 [22:46<00:00,  1.01it/s, Total=6532.63, Recon=5288
Epoch 3/50: 100%|█| 1385/1385 [24:19<00:00,  1.05s/it, Total=5944.91, Recon=4863
Epoch 4/50: 100%|█| 1385/1385 [25:41<00:00,  1.11s/it, Total=3593.92, Recon=2800
Epoch 5/50: 100%|█| 1385/1385 [3:03:14<00:00,  7.94s/it, Total=3478.49, Recon=24
Epoch 6/50: 100%|█| 1385/1385 [20:54<00:00,  1.10it/s, Total=3778.11, Recon=2970
Epoch 7/50: 100%|█| 1385/1385 [20:43<00:00,  1.11it/s, Total=3396.31, Recon=2742
Epoch 8/50: 100%|█| 1385/1385 [20:24<00:00,  1.13it/s, Total=2558.42, Recon=1957
Epoch 9/50: 100%|█| 1385/1385 [38:33<00:00,  1.67s/it, Total=2370.21, Recon=1741
Epoch 10/50: 100%|█| 1385/1385 [22:13<00:00,  1.04it/s, Total=2831.01, Recon=214
Epoch 11/50: 100%|█| 1385/1385 [23:46<00:00,  1.03s/it, Total=2294.37, Recon=159
Epoch 12/50: 100%|█| 1385/1385 [22:39<00:00,  1.02it/s, Total=2399.12, Recon=170
Epoch 13/50: 100%|█| 1385/13

In [None]:
x = torch.randn(1, 3, 256, 256)
out = vae.encoder(x)
print(out.shape)

torch.Size([1, 512, 8, 8])


In [27]:
torch.save(vae.state_dict(), "vae_weights.pth")

In [32]:
import torch
from torchvision.utils import save_image
from resnet_vae import ResNetVAE

# Device setup
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Load your trained VAE model
vae = ResNetVAE(latent_dim=256).to(device)
vae.load_state_dict(torch.load("/Users/poulam/Desktop/Generative Bullshit/vae_weights.pth", map_location=device))
vae.eval()

# Generate a random latent vector
z = torch.randn(1, vae.latent_dim).to(device)

# Decode it to an image
with torch.no_grad():
    generated = vae.decode(z)

# Save the generated image to your project directory
save_path = "/Users/poulam/Desktop/Generative Bullshit/generated_sample.png"
save_image(generated, save_path)

print(f"✅ Image successfully saved at: {save_path}")

AttributeError: 'ResNetVAE' object has no attribute 'latent_dim'