In [None]:
import os
import torch
import torch.nn as nn
import torchvision.transforms as transforms
from PIL import Image
import matplotlib.pyplot as plt
import torchvision.utils
from google.colab import files

# Preprocessing pipeline (must stay identical to the one used in training):
# resize to the generator's 256x256 input size, convert the PIL image to a
# tensor, then shift/scale every RGB channel with mean 0.5 and std 0.5.
transform = transforms.Compose([
    transforms.Resize(size=(256, 256)),
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
])

def denormalize(tensor):
    """Undo the (x - 0.5) / 0.5 normalization, mapping [-1, 1] back to [0, 1].

    Args:
        tensor: value(s) in normalized space; anything supporting ``*`` and
            ``+`` with a float works (e.g. a torch tensor).

    Returns:
        ``tensor * 0.5 + 0.5``.
    """
    scale = 0.5
    offset = 0.5
    return tensor * scale + offset

# ---------------------------
# TCN Block (matching training code)
# ---------------------------
class TCNBlock(nn.Module):
    """Residual dilated 1-D convolution block applied to a flattened feature map.

    The (B, C, H, W) input is flattened to a (B, C, H*W) sequence, passed
    through Conv1d -> BatchNorm1d -> ReLU, reshaped back to (B, C, H, W) and
    added to the original input.

    Args:
        channels: number of input and output channels of the convolution.
        kernel_size: width of the 1-D convolution kernel.
        dilation: dilation factor of the 1-D convolution.
    """

    def __init__(self, channels, kernel_size=3, dilation=2):
        super().__init__()
        # Symmetric padding that keeps the sequence length unchanged whenever
        # (kernel_size - 1) * dilation is even (always true for odd kernels).
        padding = (kernel_size - 1) * dilation // 2
        # Attribute names are part of the checkpoint's state_dict keys; keep them.
        self.conv1d = nn.Conv1d(channels, channels, kernel_size, padding=padding, dilation=dilation)
        self.bn = nn.BatchNorm1d(channels)
        self.relu = nn.ReLU()

    def forward(self, x):
        """Apply the block to ``x`` of shape (B, C, H, W); output has the same shape."""
        B, C, H, W = x.shape
        # reshape instead of view: view raises on inputs whose H/W strides
        # cannot be merged (e.g. cropped or transposed feature maps), while
        # reshape only copies when it has to.
        x_flat = x.reshape(B, C, H * W)
        out = self.conv1d(x_flat)
        out = self.bn(out)
        out = self.relu(out)
        # Restore the spatial layout before adding the residual.
        out = out.reshape(B, C, H, W)
        return out + x  # residual connection

# ---------------------------
# UNetTCNGenerator (matching training code)
# ---------------------------
class UNetTCNGenerator(nn.Module):
    """U-Net style generator with a stack of TCN blocks at the bottleneck.

    Five stride-2 convolutions halve the spatial size at every step, three
    ``TCNBlock`` layers (dilation 1, 2, 4) process the bottleneck, and five
    stride-2 transposed convolutions restore the resolution. Every decoder
    stage except the first receives the matching encoder output concatenated
    on the channel axis. The final activation is ``Tanh``.

    Submodule names and layer order match the training code so that saved
    state dicts load unchanged.
    """

    @staticmethod
    def _down(in_ch, out_ch, norm=True):
        """Build one encoder stage: Conv2d(4, stride 2) [+ BatchNorm2d] + LeakyReLU(0.2)."""
        layers = [nn.Conv2d(in_ch, out_ch, kernel_size=4, stride=2, padding=1)]
        if norm:
            layers.append(nn.BatchNorm2d(out_ch))
        layers.append(nn.LeakyReLU(0.2))
        return nn.Sequential(*layers)

    @staticmethod
    def _up(in_ch, out_ch):
        """Build one decoder stage: ConvTranspose2d(4, stride 2) + BatchNorm2d + ReLU."""
        return nn.Sequential(
            nn.ConvTranspose2d(in_ch, out_ch, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(out_ch),
            nn.ReLU(),
        )

    def __init__(self):
        super().__init__()

        # Encoder (sizes in comments are for a 256x256 input)
        self.enc1 = self._down(3, 64, norm=False)  # 256 -> 128, no batch norm
        self.enc2 = self._down(64, 128)            # 128 -> 64
        self.enc3 = self._down(128, 256)           # 64 -> 32
        self.enc4 = self._down(256, 512)           # 32 -> 16
        self.enc5 = self._down(512, 512)           # 16 -> 8

        # Bottleneck: TCN blocks with growing dilation
        self.tcn1 = TCNBlock(channels=512, kernel_size=3, dilation=1)
        self.tcn2 = TCNBlock(channels=512, kernel_size=3, dilation=2)
        self.tcn3 = TCNBlock(channels=512, kernel_size=3, dilation=4)

        # Decoder; input channels are doubled where a skip tensor is concatenated
        self.dec5 = self._up(512, 512)       # 8 -> 16
        self.dec4 = self._up(512 * 2, 256)   # 16 -> 32
        self.dec3 = self._up(256 * 2, 128)   # 32 -> 64
        self.dec2 = self._up(128 * 2, 64)    # 64 -> 128
        self.dec1 = nn.Sequential(
            nn.ConvTranspose2d(64 * 2, 3, kernel_size=4, stride=2, padding=1),  # 128 -> 256
            nn.Tanh(),  # output in [-1, 1]
        )

    def forward(self, x):
        """Run the generator on ``x`` (B, 3, H, W) and return a (B, 3, H, W) tensor."""
        # Encoder: remember every stage output for the skip connections.
        skips = []
        h = x
        for encoder in (self.enc1, self.enc2, self.enc3, self.enc4, self.enc5):
            h = encoder(h)
            skips.append(h)

        # Bottleneck with TCN blocks
        for tcn in (self.tcn1, self.tcn2, self.tcn3):
            h = tcn(h)

        # First decoder stage takes the bottleneck alone ...
        h = self.dec5(h)
        # ... the remaining ones take [decoder output, encoder skip] with the
        # skips consumed deepest-first: e4, e3, e2, e1.
        decoders = (self.dec4, self.dec3, self.dec2, self.dec1)
        for decoder, skip in zip(decoders, reversed(skips[:-1])):
            h = decoder(torch.cat([h, skip], 1))

        return h

# ---------------------------
# Helper functions for file handling in Colab
# ---------------------------
def upload_file(prompt_text):
    """Prompt for a file upload in Colab and return the uploaded file's name.

    Args:
        prompt_text: message printed before the upload widget is shown.

    Returns:
        The name of the first uploaded file.

    Raises:
        FileNotFoundError: if the upload dialog returned no files.
    """
    print(prompt_text)
    received = files.upload()
    if received:
        # Only the first (and usually only) uploaded file is used.
        return next(iter(received.keys()))
    raise FileNotFoundError("No file was uploaded!")

def ensure_file_exists(file_path, upload_prompt):
    """Return a usable file path, asking for an upload when ``file_path`` is not a file.

    Args:
        file_path: path to check.
        upload_prompt: message passed to ``upload_file`` when an upload is needed.

    Returns:
        ``file_path`` unchanged if it points to an existing regular file,
        otherwise the name returned by ``upload_file``.

    Raises:
        FileNotFoundError: propagated from ``upload_file`` when nothing is uploaded.
    """
    # isfile rather than exists: a directory (or an empty string) is not
    # something the model/image loaders can open, so fall back to an upload.
    if os.path.isfile(file_path):
        return file_path
    print(f"File not found: {file_path}")
    print("Please upload the file now.")
    return upload_file(upload_prompt)

# ---------------------------
# Model Loading and Inference Functions
# ---------------------------
def load_generator(model_path, device):
    """Build a ``UNetTCNGenerator`` on ``device`` and load its weights.

    Args:
        model_path: path to a file holding the generator's state dict.
        device: torch device the model and weights are placed on.

    Returns:
        The generator with weights loaded, switched to eval mode.
    """
    net = UNetTCNGenerator().to(device)
    # NOTE(review): torch.load unpickles the file - only load checkpoints from
    # a trusted source (consider weights_only=True where the installed
    # PyTorch version supports it).
    weights = torch.load(model_path, map_location=device)
    net.load_state_dict(weights)
    # eval() returns the module itself, with BatchNorm using running statistics.
    return net.eval()

def cartoonize_image(model_path, image_path):
    """Cartoonize one image with the trained generator and show it next to the original.

    Missing files are requested interactively through ``ensure_file_exists``.

    Args:
        model_path: path to the generator's saved state dict.
        image_path: path to the input image.

    Returns:
        The generated image as a PIL image.

    Raises:
        FileNotFoundError: propagated from the upload helper when a required
            file is missing and nothing is uploaded.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Ensure both files exist (or get them via upload)
    model_path = ensure_file_exists(model_path, "Upload your trained generator model (.pth file):")
    image_path = ensure_file_exists(image_path, "Upload your real face image:")

    # Load the generator model
    generator = load_generator(model_path, device)

    # Load the input image. convert() returns a fully loaded copy, so the
    # underlying file handle can be closed immediately instead of lingering
    # until garbage collection.
    with Image.open(image_path) as source:
        input_image = source.convert("RGB")
    # Add the batch dimension the generator expects.
    input_tensor = transform(input_image).unsqueeze(0).to(device)

    # Generate the cartoonized version (no gradients needed for inference)
    with torch.no_grad():
        cartoon_tensor = generator(input_tensor)

    # Drop the batch dimension, map [-1, 1] -> [0, 1] and convert to a PIL image;
    # clamp guards against values slightly outside the valid range.
    cartoon_tensor = denormalize(cartoon_tensor.squeeze(0)).cpu().clamp(0, 1)
    cartoon_image = transforms.ToPILImage()(cartoon_tensor)

    # Display side-by-side
    panels = (
        (input_image, "Original Real Face"),
        (cartoon_image, "Cartoonized Face"),
    )
    plt.figure(figsize=(10, 5))
    for position, (picture, caption) in enumerate(panels, start=1):
        plt.subplot(1, 2, position)
        plt.imshow(picture)
        plt.title(caption)
        plt.axis("off")

    plt.show()
    return cartoon_image

# ---------------------------
# Main Execution
# ---------------------------
if __name__ == "__main__":
    # Default file paths. They are empty here, so the existence check inside
    # cartoonize_image fails for both and each file is requested via upload.
    default_model_path = ''
    default_image_path = ''

    # Cartoonize using the paths above; any path that is not found is replaced
    # by an interactively uploaded file. The result is the generated PIL image.
    cartoonized_image = cartoonize_image(default_model_path, default_image_path)


File not found: 
Please upload the file now.
Upload your trained generator model (.pth file):
