In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import transforms
from PIL import Image
import numpy as np
import matplotlib.pyplot as plt


In [None]:
class SketchEncoder(nn.Module):
    """Encodes a single-channel sketch/doodle into a latent vector.

    Four stride-2 convolutions each halve the spatial size
    (``image_size`` -> ``image_size / 16``). A linear layer then projects the
    flattened feature map to ``latent_dim``.

    Args:
        image_size: height/width of the square input sketch. Must be a positive
            multiple of 16. The default of 128 reproduces the original
            hard-coded ``512 * 8 * 8`` projection.
        latent_dim: size of the output latent vector (default 256).

    Raises:
        ValueError: if ``image_size`` is not a positive multiple of 16.
    """

    def __init__(self, image_size=128, latent_dim=256):
        super(SketchEncoder, self).__init__()
        if image_size < 16 or image_size % 16 != 0:
            raise ValueError(f"image_size must be a positive multiple of 16, got {image_size}")
        # Registration order (conv1..conv4, fc) is kept so state_dict keys are unchanged.
        self.conv1 = nn.Conv2d(1, 64, 4, stride=2, padding=1)
        self.conv2 = nn.Conv2d(64, 128, 4, stride=2, padding=1)
        self.conv3 = nn.Conv2d(128, 256, 4, stride=2, padding=1)
        self.conv4 = nn.Conv2d(256, 512, 4, stride=2, padding=1)
        # kernel 4 / stride 2 / padding 1 halves H and W exactly, four times.
        feat_size = image_size // 16
        self.fc = nn.Linear(512 * feat_size * feat_size, latent_dim)

    def forward(self, x):
        """Map sketches of shape (B, 1, image_size, image_size) to latents (B, latent_dim)."""
        for conv in (self.conv1, self.conv2, self.conv3, self.conv4):
            x = F.leaky_relu(conv(x), 0.2)
        # flatten (unlike view) also works on non-contiguous tensors.
        return self.fc(torch.flatten(x, 1))

In [None]:
class ImageGenerator(nn.Module):
    """Decodes a latent vector into a 3-channel image with values in [-1, 1].

    A linear layer reshapes the latent vector to a 512-channel
    ``image_size / 16`` square map. Four stride-2 transposed convolutions then
    double the spatial size back to ``image_size``. A final ``tanh`` bounds the
    output to [-1, 1].

    Args:
        image_size: height/width of the square output image. Must be a positive
            multiple of 16. The default of 128 reproduces the original
            hard-coded ``512 * 8 * 8`` layout.
        latent_dim: size of the input latent vector (default 256).

    Raises:
        ValueError: if ``image_size`` is not a positive multiple of 16.
    """

    def __init__(self, image_size=128, latent_dim=256):
        super(ImageGenerator, self).__init__()
        if image_size < 16 or image_size % 16 != 0:
            raise ValueError(f"image_size must be a positive multiple of 16, got {image_size}")
        self.base_size = image_size // 16
        self.fc = nn.Linear(latent_dim, 512 * self.base_size * self.base_size)
        # kernel 4 / stride 2 / padding 1 exactly doubles H and W each time.
        self.deconv1 = nn.ConvTranspose2d(512, 256, 4, stride=2, padding=1)
        self.deconv2 = nn.ConvTranspose2d(256, 128, 4, stride=2, padding=1)
        self.deconv3 = nn.ConvTranspose2d(128, 64, 4, stride=2, padding=1)
        self.deconv4 = nn.ConvTranspose2d(64, 3, 4, stride=2, padding=1)

    def forward(self, z):
        """Map latents (B, latent_dim) to images (B, 3, image_size, image_size)."""
        x = F.relu(self.fc(z))
        x = x.view(x.size(0), 512, self.base_size, self.base_size)
        for deconv in (self.deconv1, self.deconv2, self.deconv3):
            x = F.relu(deconv(x))
        return torch.tanh(self.deconv4(x))

In [None]:
class StyleTransferModule(nn.Module):
    """Computes a Gram-matrix style loss between two images using frozen VGG features.

    Both images pass through ``vgg.features``. Activations are collected at
    ``style_layers``, and the loss is the sum of MSEs between the normalized
    Gram matrices of the two images at each of those layers.

    Args:
        vgg: optional backbone exposing a ``features`` ``nn.Sequential``. When
            omitted, a pretrained torchvision VGG19 is downloaded via
            ``torch.hub`` (original behaviour). Passing one allows offline use.
        style_layers: indices into ``vgg.features`` whose outputs are used.

    Raises:
        ValueError: if ``style_layers`` is empty.
    """

    # relu1_1, relu2_1, relu3_1, relu4_1, relu5_1 in torchvision's VGG19 ``features``.
    DEFAULT_STYLE_LAYERS = (1, 6, 11, 20, 29)

    def __init__(self, vgg=None, style_layers=DEFAULT_STYLE_LAYERS):
        super(StyleTransferModule, self).__init__()
        if vgg is None:
            # Pretrained VGG for style extraction (downloads weights on first use).
            vgg = torch.hub.load('pytorch/vision:v0.10.0', 'vgg19', pretrained=True)
        self.vgg = vgg
        for param in self.vgg.parameters():
            param.requires_grad = False
        self.style_layers = tuple(sorted(set(style_layers)))
        if not self.style_layers:
            raise ValueError("style_layers must contain at least one layer index")
        # ImageNet statistics the pretrained VGG weights expect. Non-persistent so
        # they follow .to(device) but do not change the state_dict.
        self.register_buffer("mean", torch.tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1),
                             persistent=False)
        self.register_buffer("std", torch.tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1),
                             persistent=False)

    def _extract_features(self, img):
        """Return the activations of ``img`` at each index in ``self.style_layers``.

        NOTE(review): assumes ``img`` is (B, 3, H, W) in [-1, 1], matching the
        generator's tanh output. Style images must use the same range — TODO
        confirm against the data pipeline.
        """
        x = ((img + 1) / 2 - self.mean) / self.std
        features = []
        last = self.style_layers[-1]
        for idx, layer in enumerate(self.vgg.features):
            x = layer(x)
            if idx in self.style_layers:
                features.append(x)
            if idx >= last:
                break  # deeper layers do not contribute to the loss
        return features

    @staticmethod
    def _gram_matrix(feat):
        """Batched Gram matrix (B, C, C) of a (B, C, H, W) feature map, normalized by C*H*W."""
        b, c, h, w = feat.shape
        flat = feat.reshape(b, c, h * w)
        return torch.bmm(flat, flat.transpose(1, 2)) / (c * h * w)

    def forward(self, generated_img, style_img):
        """Return the scalar style loss of ``generated_img`` w.r.t. ``style_img``.

        A style batch of size 1 is broadcast over the whole generated batch.
        """
        gen_features = self._extract_features(generated_img)
        style_features = self._extract_features(style_img)

        style_loss = generated_img.new_zeros(())
        for gen_feat, style_feat in zip(gen_features, style_features):
            G = self._gram_matrix(gen_feat)
            A = self._gram_matrix(style_feat)
            if A.shape[0] == 1 and G.shape[0] != 1:
                A = A.expand_as(G)
            style_loss = style_loss + F.mse_loss(G, A)
        return style_loss

In [None]:
class DreamSketch(nn.Module):
    """End-to-end sketch-to-image model.

    The sketch is encoded to a latent vector, which is decoded into an image.
    When a style image is also supplied, the style loss of the generated image
    is computed and returned alongside it.

    Returns:
        ``generated_img`` when ``style_img`` is None, otherwise the pair
        ``(generated_img, style_loss)``.
    """

    def __init__(self):
        super().__init__()
        self.sketch_encoder = SketchEncoder()
        self.image_generator = ImageGenerator()
        self.style_transfer = StyleTransferModule()

    def forward(self, sketch, style_img=None):
        image = self.image_generator(self.sketch_encoder(sketch))
        if style_img is None:
            return image
        return image, self.style_transfer(image, style_img)

In [None]:
def train_model(train_loader=None, model=None, num_epochs=100, lr=2e-4,
                style_weight=0.1, device=None):
    """Train a DreamSketch-style model with L1 reconstruction plus a weighted style loss.

    Args:
        train_loader: iterable yielding ``(sketch_batch, photo_batch, style_batch)``
            tensors. If omitted, the global ``train_loader`` is used (the
            original notebook behaviour).
            NOTE(review): ``photo_batch`` is presumably scaled to [-1, 1] to
            match the generator's tanh output — TODO confirm.
        model: model whose ``forward(sketch, style)`` returns
            ``(generated_img, style_loss)``. Defaults to a fresh ``DreamSketch()``.
        num_epochs: number of passes over ``train_loader``.
        lr: Adam learning rate.
        style_weight: multiplier applied to the style loss.
        device: torch device. Defaults to CUDA when available, else CPU.

    Returns:
        The trained model (the same object when ``model`` was passed in).

    Raises:
        NameError: if no loader is passed and no global ``train_loader`` exists.
        ValueError: if ``train_loader`` yields no batches in an epoch.
    """
    if train_loader is None:
        train_loader = globals().get("train_loader")
        if train_loader is None:
            raise NameError("train_loader was not passed and no global 'train_loader' is defined")
    if device is None:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    if model is None:
        model = DreamSketch()
    model = model.to(device)
    model.train()

    reconstruction_loss = nn.L1Loss()
    # Only hand trainable parameters to Adam; frozen backbone weights (e.g. VGG) are skipped.
    optimizer = torch.optim.Adam(
        [p for p in model.parameters() if p.requires_grad], lr=lr, betas=(0.5, 0.999)
    )

    for epoch in range(num_epochs):
        running_loss = 0.0
        num_batches = 0
        for sketch_batch, photo_batch, style_batch in train_loader:
            sketch_batch = sketch_batch.to(device)
            photo_batch = photo_batch.to(device)
            style_batch = style_batch.to(device)

            # Generate image and calculate losses
            generated_img, style_loss = model(sketch_batch, style_batch)
            recon_loss = reconstruction_loss(generated_img, photo_batch)
            total_loss = recon_loss + style_weight * style_loss

            # Backpropagation
            optimizer.zero_grad()
            total_loss.backward()
            optimizer.step()

            running_loss += total_loss.item()
            num_batches += 1

        if num_batches == 0:
            raise ValueError("train_loader yielded no batches")
        # Report the epoch mean rather than only the last batch's loss.
        print(f"Epoch {epoch}, Loss: {running_loss / num_batches:.4f}")

    return model