In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision.transforms as transforms
import torchvision
from torch.utils.data import DataLoader, TensorDataset
from torchvision.datasets import ImageFolder
import os
from PIL import Image
import numpy as np
import tqdm as tqdm

In [2]:
class Autoencoder(nn.Module):
    """Convolutional autoencoder for single-channel images.

    The encoder is five ``Conv2d -> BatchNorm2d -> ReLU`` stages (kernel 3,
    stride 2, padding 1) that widen the channels 1 -> 32 -> 64 -> 128 -> 256 -> 512.
    The decoder mirrors it with five ``ConvTranspose2d -> BatchNorm2d`` stages
    (kernel 3, stride 2, padding 1, output_padding 1); the first four are
    followed by ``ReLU`` and the last one by ``Sigmoid``.

    With these settings every encoder stage maps a spatial size ``H`` to
    ``ceil(H / 2)`` and every decoder stage maps ``H`` to ``2 * H``, so the
    output has the input's spatial size when that size is a multiple of 32.
    """

    def __init__(self):
        super().__init__()

        # Channel count before/after each of the five encoder stages; the
        # decoder walks the same schedule backwards.
        channels = (1, 32, 64, 128, 256, 512)

        # Encoder layers
        down = []
        for c_in, c_out in zip(channels[:-1], channels[1:]):
            down.append(nn.Conv2d(c_in, c_out, kernel_size=3, stride=2, padding=1))
            down.append(nn.BatchNorm2d(c_out))
            down.append(nn.ReLU())
        self.encoder = nn.Sequential(*down)

        # Decoder layers
        mirrored = channels[::-1]
        final_stage = len(mirrored) - 2
        up = []
        for stage, (c_in, c_out) in enumerate(zip(mirrored[:-1], mirrored[1:])):
            up.append(
                nn.ConvTranspose2d(
                    c_in, c_out, kernel_size=3, stride=2, padding=1, output_padding=1
                )
            )
            up.append(nn.BatchNorm2d(c_out))
            # The final stage ends in a sigmoid so output pixel values lie
            # between 0 and 1; every other stage uses ReLU.
            up.append(nn.Sigmoid() if stage == final_stage else nn.ReLU())
        self.decoder = nn.Sequential(*up)

    def forward(self, x):
        """Encode ``x`` and return its reconstruction from the decoder."""
        return self.decoder(self.encoder(x))


In [4]:
def transform_test_images_to_tensor(path):
    """Load every file in ``path`` as a 512x512 grayscale image tensor.

    Each entry returned by ``os.listdir(path)`` is opened with PIL, converted
    to single-channel mode ``'L'``, resized to 512x512 and passed through
    ``transforms.ToTensor()``.

    Args:
        path: Directory that contains the images. A trailing path separator
            is optional.

    Returns:
        list: One tensor per directory entry, in ``os.listdir(path)`` order.
        Callers that pair the results with per-file metadata must enumerate
        the directory the same way.
    """
    transform = transforms.Compose([
        transforms.Resize((512, 512)),
        transforms.ToTensor()
    ])

    images = []
    for filename in os.listdir(path):
        # os.path.join (rather than `path + filename`) so the function also
        # works when `path` is given without a trailing separator.
        with Image.open(os.path.join(path, filename)) as img:
            # Convert/transform inside the `with` so the pixel data is read
            # before the underlying file handle is closed.
            images.append(transform(img.convert('L')))
    return images

In [5]:
# Directory path containing the images
directory = 'TestImage/'
# Directory the generated sketches are written to
output_directory = 'TestSketch/'

# Original (height, width) of every input image, used to resize each sketch back.
# PIL's Image.size is (width, height) while F.interpolate's `size` for a 4-D
# tensor is (height, width), so the two values are swapped here.
image_shapes = []
# Iterate over each file in the directory. The order must match the order used by
# transform_test_images_to_tensor, because shapes and outputs are paired by index.
for filename in os.listdir(directory):
    # Construct the full file path
    file_path = os.path.join(directory, filename)

    # Open the image using PIL (closed again as soon as the size is read)
    with Image.open(file_path) as image:
        width, height = image.size
    image_shapes.append((height, width))

images = transform_test_images_to_tensor(directory)
images = torch.stack(images)
print(images.shape)

model = Autoencoder()  # Replace with your model architecture

model_path = 'model.pth'
# NOTE(security): torch.load unpickles the file, which can execute arbitrary
# code -- only load checkpoints that come from a trusted source.
checkpoint = torch.load(model_path , map_location=torch.device('cpu'))
if isinstance(checkpoint, nn.Module):
    # The file holds a whole pickled model.
    model = checkpoint
else:
    # Otherwise treat it as a state dict for the architecture built above.
    model.load_state_dict(checkpoint)

# Inference mode: BatchNorm layers use their running statistics instead of the
# statistics of this particular batch.
model.eval()

# No gradients are needed for inference; skipping autograd saves memory.
with torch.no_grad():
    output = model(images)

# Make sure the destination exists before writing any sketch.
os.makedirs(output_directory, exist_ok=True)
for i in range(len(output)):
    resized_image = F.interpolate(output[i].unsqueeze(0), size=image_shapes[i], mode='bilinear', align_corners=False)
    torchvision.utils.save_image(resized_image, os.path.join(output_directory, 'Sketch'+str(i)+'.png'))

torch.Size([4, 1, 512, 512])
