<a href="https://colab.research.google.com/github/Abhiraj36/Compression01/blob/main/CNN_compression.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [11]:
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
from torchvision.utils import save_image
import os

# Ensure output directory exists
os.makedirs("output_images", exist_ok=True)

# Device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


# Skip-Connected Autoencoder

class SkipAutoencoder(nn.Module):
    def __init__(self):
        super(SkipAutoencoder, self).__init__()
        # Encoder
        self.enc1 = nn.Sequential(
            nn.Conv2d(3, 64, 3, padding=1), nn.ReLU(inplace=True))
        self.enc2 = nn.Sequential(
            nn.Conv2d(64, 128, 3, stride=2, padding=1), nn.ReLU(inplace=True))  # 16x16
        self.enc3 = nn.Sequential(
            nn.Conv2d(128, 256, 3, stride=2, padding=1), nn.ReLU(inplace=True))  # 8x8

        # Decoder
        self.dec1 = nn.Sequential(
            nn.ConvTranspose2d(256, 128, 4, stride=2, padding=1), nn.ReLU(inplace=True))  # 16x16
        self.dec2 = nn.Sequential(
            nn.ConvTranspose2d(256, 64, 4, stride=2, padding=1), nn.ReLU(inplace=True))  # 32x32
        self.dec3 = nn.Sequential(
            nn.Conv2d(128, 3, 3, padding=1), nn.Tanh())

    def forward(self, x):
        e1 = self.enc1(x)  # [B, 64, 32, 32]
        e2 = self.enc2(e1) # [B, 128, 16, 16]
        e3 = self.enc3(e2) # [B, 256, 8, 8]

        d1 = self.dec1(e3)                # [B, 128, 16, 16]
        d1_cat = torch.cat([d1, e2], 1)   # [B, 256, 16, 16]

        d2 = self.dec2(d1_cat)           # [B, 64, 32, 32]
        d2_cat = torch.cat([d2, e1], 1)  # [B, 128, 32, 32]

        out = self.dec3(d2_cat)          # [B, 3, 32, 32]

        return out, e3  # output and latent



# Latent Compressor

class LatentCompressor(nn.Module):
    def __init__(self, input_channels=256):
        super(LatentCompressor, self).__init__()
        self.encoder = nn.Sequential(
            nn.Conv2d(input_channels, 128, kernel_size=3, stride=2, padding=1),  # 8x8 -> 4x4
            nn.ReLU(),
            nn.Conv2d(128, 64, kernel_size=3, stride=2, padding=1),             # 4x4 -> 2x2
            nn.ReLU(),
        )
        self.decoder = nn.Sequential(
            nn.ConvTranspose2d(64, 128, kernel_size=4, stride=2, padding=1),    # 2x2 -> 4x4
            nn.ReLU(),
            nn.ConvTranspose2d(128, input_channels, kernel_size=4, stride=2, padding=1),  # 4x4 -> 8x8
        )

    def forward(self, x):
        z = self.encoder(x)
        out = self.decoder(z)
        return out


# Data Preparation

transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))  # normalize to [-1,1]
])

trainset = datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)
testset = datasets.CIFAR10(root='./data', train=False, download=True, transform=transform)

trainloader = DataLoader(trainset, batch_size=128, shuffle=True)
testloader = DataLoader(testset, batch_size=16, shuffle=False)


# Initialize Models

autoencoder = SkipAutoencoder().to(device)
latent_compressor = LatentCompressor().to(device)


# Train Autoencoder First (optional: or load pretrained)

ae_optimizer = optim.Adam(autoencoder.parameters(), lr=0.001)
ae_criterion = nn.MSELoss()

print("\nTraining Autoencoder...")
autoencoder.train()
for epoch in range(5):  # Train AE (or load pre-trained)
    total_loss = 0.0
    for imgs, _ in trainloader:
        imgs = imgs.to(device)
        ae_optimizer.zero_grad()
        recon, _ = autoencoder(imgs)
        loss = ae_criterion(recon, imgs)
        loss.backward()
        ae_optimizer.step()
        total_loss += loss.item() * imgs.size(0)

    print(f"Autoencoder Epoch [{epoch+1}/5], Loss: {total_loss / len(trainloader.dataset):.6f}")





Training Autoencoder...
Autoencoder Epoch [1/5], Loss: 0.005822
Autoencoder Epoch [2/5], Loss: 0.000605
Autoencoder Epoch [3/5], Loss: 0.000380
Autoencoder Epoch [4/5], Loss: 0.000277
Autoencoder Epoch [5/5], Loss: 0.000218


In [12]:

# Train Latent Compressor

latent_compressor.train()
comp_optimizer = optim.Adam(latent_compressor.parameters(), lr=0.001)
comp_criterion = nn.MSELoss()

print("\nTraining Latent Compressor...")
for epoch in range(10):
    total_loss = 0.0
    for imgs, _ in trainloader:
        imgs = imgs.to(device)
        with torch.no_grad():
            _, e3 = autoencoder(imgs)
        comp_optimizer.zero_grad()
        recon_e3 = latent_compressor(e3)
        loss = comp_criterion(recon_e3, e3)
        loss.backward()
        comp_optimizer.step()
        total_loss += loss.item() * imgs.size(0)

    print(f"Compressor Epoch [{epoch+1}/10], Loss: {total_loss / len(trainloader.dataset):.6f}")




Training Latent Compressor...
Compressor Epoch [1/10], Loss: 0.015379
Compressor Epoch [2/10], Loss: 0.006447
Compressor Epoch [3/10], Loss: 0.004807
Compressor Epoch [4/10], Loss: 0.004025
Compressor Epoch [5/10], Loss: 0.003538
Compressor Epoch [6/10], Loss: 0.003223
Compressor Epoch [7/10], Loss: 0.002979
Compressor Epoch [8/10], Loss: 0.002798
Compressor Epoch [9/10], Loss: 0.002653
Compressor Epoch [10/10], Loss: 0.002532


In [13]:


# Evaluation & Saving Images

autoencoder.eval()
latent_compressor.eval()

with torch.no_grad():
    imgs, _ = next(iter(testloader))
    imgs = imgs.to(device)

    # Original latent
    e1 = autoencoder.enc1(imgs)
    e2 = autoencoder.enc2(e1)
    e3 = autoencoder.enc3(e2)

    # Compress e3
    z = latent_compressor.encoder(e3)
    print(f"Compressed shape: {z.shape}")  # e.g., [B,64,2,2]

    # Decompress e3
    decoded_e3 = latent_compressor.decoder(z)

    # Decode full image using decoded_e3, e2, e1
    d1 = autoencoder.dec1(decoded_e3)
    d1_cat = torch.cat([d1, e2], 1)
    d2 = autoencoder.dec2(d1_cat)
    d2_cat = torch.cat([d2, e1], 1)
    recon_from_compressed = autoencoder.dec3(d2_cat)

    # Save all images
    save_image(imgs * 0.5 + 0.5, "output_images/original.png")
    recon_img, _ = autoencoder(imgs)
    save_image(recon_img * 0.5 + 0.5, "output_images/reconstructed.png")
    save_image(recon_from_compressed * 0.5 + 0.5, "output_images/recon_from_compressed.png")

    # Size info
    original_size = e3.nelement() * 4
    compressed_size = z.nelement() * 4
    print(f"Original e3 size: {original_size / 1024:.2f} KB")
    print(f"Compressed z size: {compressed_size / 1024:.2f} KB")

    print("Saved original, reconstructed, and compressed recon images.")


Compressed shape: torch.Size([16, 64, 2, 2])
Original e3 size: 1024.00 KB
Compressed z size: 16.00 KB
Saved original, reconstructed, and compressed recon images.


In [14]:
# Simple quantization by scaling and rounding
z_quantized = torch.round(z * 127)  # Scale latent values roughly to int8 range [-127, 127]
z_quantized = z_quantized.to(torch.int8)  # Convert tensor to int8 type for smaller size


In [15]:
# Convert to NumPy array on CPU and save as raw binary file
z_np = z_quantized.cpu().numpy()
os.makedirs("output_images", exist_ok=True)  # Make sure output folder exists
z_np.tofile("output_images/compressed_latent.bin")

print("Quantized latent saved as binary file: output_images/compressed_latent.bin")


Quantized latent saved as binary file: output_images/compressed_latent.bin


In [16]:
# Calculate and print size of quantized latent in KB
quantized_size = os.path.getsize("output_images/compressed_latent.bin") / 1024
print(f"Quantized latent binary size: {quantized_size:.2f} KB")


Quantized latent binary size: 4.00 KB


In [17]:
# Load quantized latent from binary file
loaded_np = np.fromfile("output_images/compressed_latent.bin", dtype=np.int8)
loaded_tensor = torch.from_numpy(loaded_np).float() / 127  # Scale back to approx original range

# Reshape to latent tensor shape (same as z)
loaded_tensor = loaded_tensor.view_as(z).to(device)

# Decode latent as usual
with torch.no_grad():
    decoded_e3_loaded = latent_compressor.decoder(loaded_tensor)

    d1_loaded = autoencoder.dec1(decoded_e3_loaded)
    d1_cat_loaded = torch.cat([d1_loaded, e2], 1)
    d2_loaded = autoencoder.dec2(d1_cat_loaded)
    d2_cat_loaded = torch.cat([d2_loaded, e1], 1)
    recon_loaded = autoencoder.dec3(d2_cat_loaded)

    save_image(recon_loaded * 0.5 + 0.5, "output_images/recon_from_loaded_compressed.png")
    print("Saved reconstructed image from loaded compressed latent.")


Saved reconstructed image from loaded compressed latent.
