# In [None]:
import os
import sys

import numpy as np
import PIL
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
from packaging import version


# Generator (U-Net with skip connections)
class UNetGenerator(nn.Module):
    """U-Net generator: 8 stride-2 encoder stages, 7 decoder stages plus a final
    output stage, with encoder outputs concatenated into the decoder (skips).

    Args:
        input_nc: channels of the input image tensor.
        output_nc: channels of the generated output tensor.
        ngf: base filter count; deeper stages use multiples of it, capped at
            ``ngf * 8``.

    Every encoder stage halves H and W, and every decoder stage doubles them.
    H and W must therefore be at least 256 (2**8). Multiples of 256 keep every
    skip tensor shape-aligned for ``torch.cat``.

    The submodule names ``down1``..``down8``, ``up1``..``up7`` and ``final``,
    together with their internal Sequential indices, form the ``state_dict``
    keys. They must stay stable for checkpoints to load.
    """

    def __init__(self, input_nc=3, output_nc=1, ngf=64):
        super().__init__()

        def down_block(in_channels, out_channels, apply_batchnorm=True):
            # Conv(k=4, s=2, p=1, no bias) -> [BatchNorm] -> LeakyReLU(0.2): halves H, W.
            stage = [nn.Conv2d(in_channels, out_channels, kernel_size=4, stride=2, padding=1, bias=False)]
            if apply_batchnorm:
                stage.append(nn.BatchNorm2d(out_channels))
            stage.append(nn.LeakyReLU(0.2, inplace=True))
            return nn.Sequential(*stage)

        def up_block(in_channels, out_channels, apply_dropout=False):
            # ConvTranspose(k=4, s=2, p=1, no bias) -> BatchNorm -> ReLU -> [Dropout]: doubles H, W.
            stage = [
                nn.ConvTranspose2d(in_channels, out_channels, kernel_size=4, stride=2, padding=1, bias=False),
                nn.BatchNorm2d(out_channels),
                nn.ReLU(inplace=True),
            ]
            if apply_dropout:
                stage.append(nn.Dropout(0.5))
            return nn.Sequential(*stage)

        widest = ngf * 8

        # (in_channels, out_channels, apply_batchnorm) for down1..down8.
        # Modules are registered in this exact order, which keeps both the
        # state_dict key order and the parameter-initialisation RNG order.
        encoder_plan = [
            (input_nc, ngf, False),
            (ngf, ngf * 2, True),
            (ngf * 2, ngf * 4, True),
            (ngf * 4, widest, True),
            (widest, widest, True),
            (widest, widest, True),
            (widest, widest, True),
            (widest, widest, False),
        ]
        for stage_no, (c_in, c_out, use_bn) in enumerate(encoder_plan, start=1):
            setattr(self, f"down{stage_no}", down_block(c_in, c_out, apply_batchnorm=use_bn))

        # (in_channels, out_channels, apply_dropout) for up1..up7. From up2 on,
        # the input is the previous decoder output concatenated with an encoder
        # skip, hence the doubled input widths.
        decoder_plan = [
            (widest, widest, True),
            (widest * 2, widest, True),
            (widest * 2, widest, True),
            (widest * 2, widest, False),
            (widest * 2, ngf * 4, False),
            (ngf * 4 * 2, ngf * 2, False),
            (ngf * 2 * 2, ngf, False),
        ]
        for stage_no, (c_in, c_out, use_dropout) in enumerate(decoder_plan, start=1):
            setattr(self, f"up{stage_no}", up_block(c_in, c_out, apply_dropout=use_dropout))

        # Last upsampling back to input resolution; Tanh bounds output to [-1, 1].
        self.final = nn.Sequential(
            nn.ConvTranspose2d(ngf * 2, output_nc, kernel_size=4, stride=2, padding=1),
            nn.Tanh(),
        )

    def forward(self, x):
        """Map ``x`` to an ``output_nc``-channel tensor at the same H and W."""
        # Encoder: collect every stage output; skips[k-1] holds d_k.
        skips = []
        feat = x
        for stage_no in range(1, 9):
            feat = getattr(self, f"down{stage_no}")(feat)
            skips.append(feat)

        # Decoder: the bottleneck d8 feeds up1 directly. Each following stage
        # pops the deepest unused skip (d7, d6, ..., d2) and concatenates it
        # on the channel axis.
        feat = self.up1(skips.pop())
        for stage_no in range(2, 8):
            feat = getattr(self, f"up{stage_no}")(torch.cat([feat, skips.pop()], 1))
        # Only d1 remains for the output stage.
        return self.final(torch.cat([feat, skips.pop()], 1))


# PatchGAN Discriminator
class PatchGANDiscriminator(nn.Module):
    """Conditional PatchGAN discriminator that scores an (img, sketch) pair.

    Args:
        input_nc: channels expected in ``img``. Presumably the conditioning
            image; TODO confirm against the training script.
        output_nc: channels expected in ``sketch``.
        ndf: base filter count; later stages use 2x, 4x and 8x it.

    The two inputs are concatenated along dim 1, so the first convolution takes
    ``input_nc + output_nc`` channels.

    The layer layout inside ``self.model`` determines the ``state_dict`` keys
    (``model.<i>.<j>...``). Keep it unchanged so existing checkpoints still load.
    """

    def __init__(self, input_nc=3, output_nc=1, ndf=64):
        super().__init__()

        def disc_block(in_channels, out_channels, stride=2, apply_batchnorm=True):
            """Build Conv2d(k=4, p=1, no bias) -> optional BatchNorm2d -> LeakyReLU(0.2)."""
            layers = [nn.Conv2d(in_channels, out_channels, kernel_size=4, stride=stride, padding=1, bias=False)]
            if apply_batchnorm:
                layers.append(nn.BatchNorm2d(out_channels))
            layers.append(nn.LeakyReLU(0.2, inplace=True))
            return nn.Sequential(*layers)

        self.model = nn.Sequential(
            disc_block(input_nc + output_nc, ndf, apply_batchnorm=False),
            disc_block(ndf, ndf * 2),
            disc_block(ndf * 2, ndf * 4),
            # The last two convolutions use stride 1 with k=4, p=1: each shrinks
            # H and W by 1 instead of halving them.
            disc_block(ndf * 4, ndf * 8, stride=1),
            nn.Conv2d(ndf * 8, 1, kernel_size=4, stride=1, padding=1),
            # The output is already a probability; a loss such as
            # BCEWithLogitsLoss would apply a second sigmoid on top of this.
            nn.Sigmoid(),
        )

    def forward(self, img, sketch):
        """Return a one-channel map of per-patch scores in [0, 1].

        Assumes batched (N, C, H, W) inputs. ``img`` and ``sketch`` are
        concatenated on dim 1, so they must share N, H and W. For a 256x256
        input the score map is 30x30.
        """
        # Named ``pair`` rather than ``input`` to avoid shadowing the builtin.
        pair = torch.cat((img, sketch), 1)
        return self.model(pair)


# Diagnostics script, in four steps:
#   1. Print library and CUDA versions.
#   2. Check that the checkpoint file exists.
#   3. Load it and check for a "generator_state" entry.
#   4. Load that entry into UNetGenerator and list both key/shape sets for
#      comparison.
# Failures that make later steps impossible stop the script with exit status 1.
# This uses sys.exit rather than the site/IPython ``exit()`` helper: that helper
# exits with status 0, may be absent (python -S), and under IPython may not stop
# the running cell.

# Define paths
CHECKPOINT_PATH = "pix2pix_checkpoint_pencil.pth"

# ------------------------------
# 🔹 1. Check PyTorch & CUDA Versions
# ------------------------------
print("\n🔥 Checking PyTorch & CUDA Versions:")
print(f"🔹 PyTorch Version: {torch.__version__}")
print(f"🔹 Torchvision Version: {torchvision.__version__}")
print(f"🔹 NumPy Version: {np.__version__}")
print(f"🔹 Pillow Version: {PIL.__version__}")

# Check CUDA
cuda_available = torch.cuda.is_available()
print(f"🔹 CUDA Available: {cuda_available}")
if cuda_available:
    print(f"🔹 CUDA Device Count: {torch.cuda.device_count()}")
    print(f"🔹 CUDA Current Device: {torch.cuda.current_device()}")
    print(f"🔹 CUDA Device Name: {torch.cuda.get_device_name(torch.cuda.current_device())}")

# ------------------------------
# 🔹 2. Verify Checkpoint File
# ------------------------------
print("\n📂 Checking Checkpoint File:")
# isfile (not exists) so a directory with this name is reported as missing.
if os.path.isfile(CHECKPOINT_PATH):
    print(f"✅ Checkpoint found at: {os.path.abspath(CHECKPOINT_PATH)}")
    print(f"✅ Checkpoint Size: {os.path.getsize(CHECKPOINT_PATH) / (1024 * 1024):.2f} MB")
else:
    print("❌ ERROR: Checkpoint file NOT found.")
    sys.exit(1)

# ------------------------------
# 🔹 3. Load Checkpoint and Verify Keys
# ------------------------------
print("\n📜 Checking Checkpoint Keys:")
try:
    # NOTE(review): torch.load unpickles the file. Only load checkpoints from a
    # trusted source. The weights_only default differs across PyTorch versions.
    checkpoint = torch.load(CHECKPOINT_PATH, map_location="cpu")
except Exception as e:
    print(f"❌ ERROR: Could not load checkpoint! {e}")
    sys.exit(1)

if not isinstance(checkpoint, dict):
    print(f"❌ ERROR: Could not load checkpoint! Expected a dict, got {type(checkpoint).__name__}")
    sys.exit(1)

print(f"✅ Found keys: {list(checkpoint.keys())}")

if "generator_state" in checkpoint:
    print("✅ 'generator_state' is present!")
else:
    print("❌ ERROR: 'generator_state' key is MISSING!")
    # Every later step needs this entry; continuing would only raise KeyError.
    sys.exit(1)

generator_state = checkpoint["generator_state"]
if not isinstance(generator_state, dict):
    print(f"❌ ERROR: 'generator_state' is a {type(generator_state).__name__}, not a state_dict!")
    sys.exit(1)

# ------------------------------
# 🔹 4. Compare Model State Dict
# ------------------------------

print("\n🔍 Checking Model Architecture Compatibility:")
device = torch.device("cuda" if cuda_available else "cpu")
generator = UNetGenerator(input_nc=3, output_nc=1).to(device)

# Load generator state. strict=False skips missing/unexpected keys instead of
# raising, so report them explicitly rather than claiming a clean load.
# Tensor shape mismatches still raise and land in the except branch.
try:
    load_result = generator.load_state_dict(generator_state, strict=False)
except Exception as e:
    print(f"❌ ERROR: Model state_dict mismatch: {e}")
else:
    if load_result.missing_keys or load_result.unexpected_keys:
        print("⚠️ Model loaded, but some keys did not match (strict=False):")
        for key in load_result.missing_keys:
            print(f"   ➖ Missing from checkpoint: {key}")
        for key in load_result.unexpected_keys:
            print(f"   ➕ Unexpected in checkpoint: {key}")
    else:
        print("✅ Model loaded successfully!")

# Print model state keys
print("\n🛠️ Model State Dict Keys:")
for name, tensor in generator.state_dict().items():
    print(f"🔹 {name}: {tensor.shape}")

# Print checkpoint state keys (non-tensor entries show their type name instead).
print("\n📦 Checkpoint State Dict Keys:")
for key, value in generator_state.items():
    print(f"🔹 {key}: {getattr(value, 'shape', type(value).__name__)}")

print("\n✅ **Diagnostics Complete!** Compare results with the working machine.")
