In [1]:
import os
import torch
import torch.nn as nn
from torchvision import models, transforms
from PIL import Image, ImageEnhance
import gradio as gr
import torchvision.transforms as transforms

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# ResNet50 model for grayscale input
def resnet50_for_grayscale(num_classes=526):
    model = models.resnet50(pretrained=False)
    model.conv1 = nn.Conv2d(1, 64, kernel_size=7, stride=2, padding=3, bias=False)
    model.fc = nn.Linear(model.fc.in_features, num_classes)
    return model


# Generator (U-Net)
class UNetGenerator(nn.Module):
    def __init__(self):
        super(UNetGenerator, self).__init__()
        self.down1 = self.conv_block(1, 64, False)
        self.down2 = self.conv_block(64, 128)
        self.down3 = self.conv_block(128, 256)
        self.down4 = self.conv_block(256, 512)
        self.down5 = self.conv_block(512, 512)
        self.up1 = self.upconv_block(512, 512)
        self.up2 = self.upconv_block(1024, 256)
        self.up3 = self.upconv_block(512, 128)
        self.up4 = self.upconv_block(256, 64)
        self.final = nn.Sequential(
            nn.ConvTranspose2d(128, 1, kernel_size=4, stride=2, padding=1),
            nn.Tanh()
        )

    def conv_block(self, in_channels, out_channels, batch_norm=True):
        layers = [
            nn.Conv2d(in_channels, out_channels, kernel_size=4, stride=2, padding=1),
            nn.LeakyReLU(0.2, inplace=True)
        ]
        if batch_norm:
            layers.append(nn.BatchNorm2d(out_channels))
        return nn.Sequential(*layers)

    def upconv_block(self, in_channels, out_channels):
        layers = [
            nn.ConvTranspose2d(in_channels, out_channels, kernel_size=4, stride=2, padding=1),
            nn.ReLU(inplace=True),
            nn.BatchNorm2d(out_channels)
        ]
        return nn.Sequential(*layers)

    def forward(self, x):
        d1 = self.down1(x)
        d2 = self.down2(d1)
        d3 = self.down3(d2)
        d4 = self.down4(d3)
        d5 = self.down5(d4)
        u1 = self.up1(d5)
        u2 = self.up2(torch.cat([u1, d4], 1))
        u3 = self.up3(torch.cat([u2, d3], 1))
        u4 = self.up4(torch.cat([u3, d2], 1))
        return self.final(torch.cat([u4, d1], 1))

# Define the Discriminator (PatchGAN)
class PatchDiscriminator(nn.Module):
    def __init__(self):
        super(PatchDiscriminator, self).__init__()
        self.model = nn.Sequential(
            self.conv_block(2, 64, False),
            self.conv_block(64, 128),
            self.conv_block(128, 256),
            nn.Conv2d(256, 512, kernel_size=4, stride=1, padding=1),
            nn.Conv2d(512, 1, kernel_size=4, stride=1, padding=1)
        )

    def conv_block(self, in_channels, out_channels, batch_norm=True):
        layers = [
            nn.Conv2d(in_channels, out_channels, kernel_size=4, stride=2, padding=1),
            nn.LeakyReLU(0.2, inplace=True)
        ]
        if batch_norm:
            layers.append(nn.BatchNorm2d(out_channels))
        return nn.Sequential(*layers)

    def forward(self, x):
        return self.model(x)

In [3]:
# Load the trained U-Net model
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
unet_model_path = 'final_model/generator_epoch_100.pth'
unet_model = UNetGenerator().to(device)
unet_model.load_state_dict(torch.load(unet_model_path, map_location=device))
unet_model.eval()

# Load the ResNet50 model for classification
resnet_model = resnet50_for_grayscale(num_classes=526)
resnet_model.load_state_dict(torch.load('resnet50_model.pth', map_location=device))
resnet_model.to(device)
resnet_model.eval()

# Generate mdvs_to_class_mapping and reverse mapping (you may need to adjust the paths)
image_dir = 'medieval_sinhala'
mdvs_present = []

# Extract mdvsXX from filenames like 'mdvsXX.jpg'
for filename in os.listdir(image_dir):
    if filename.startswith('mdvs') and filename.endswith('.jpg'):
        number_part = int(filename.split('mdvs')[1].split('.jpg')[0].strip())
        mdvs_present.append(number_part)

mdvs_present = sorted(mdvs_present)
mdvs_to_class_mapping = {mdvs_number: idx for idx, mdvs_number in enumerate(mdvs_present)}
class_to_mdvs_mapping = {v: k for k, v in mdvs_to_class_mapping.items()}  # Reverse mapping


# Define image transformation with interpolation for better quality
transform = transforms.Compose([
    transforms.Resize((64, 64), interpolation=Image.BILINEAR),  # Use BILINEAR or LANCZOS interpolation
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5], std=[0.5])
])

# Function to get classified image using reverse mapping
def get_classified_image(label):
    try:
        # Get the actual mdvs number using the reverse mapping
        mdvs_number = class_to_mdvs_mapping.get(label)

        if mdvs_number is None:
            raise ValueError(f"Label {label} not found in reverse mapping")

        # Load the image from the real dataset
        img_path = f"medieval_sinhala/mdvs{mdvs_number} .jpg"
        if os.path.exists(img_path):
            return Image.open(img_path).convert('L')
        else:
            raise FileNotFoundError(f"Image for mdvs number {mdvs_number} not found.")
    except Exception as e:
        print(e)
        return Image.new('L', (64, 64))  # Return blank image in case of error


# Function to resize the image to 128x128
def resize_image_to_128x128(img):
    return img.resize((128, 128), Image.LANCZOS)  # Resizing to 128x128 using LANCZOS interpolation

# Function to sharpen the image
def sharpen_image(img):
    enhancer = ImageEnhance.Sharpness(img)
    return enhancer.enhance(2.0)  # Increase sharpness by a factor of 2.0

# Function to reconstruct the uploaded image and classify it
def reconstruct_and_classify(img):
    # Convert PIL image to tensor and preprocess
    img = img.convert('L')
    img = transform(img).unsqueeze(0).to(device)
    
    # Perform reconstruction with the U-Net model
    with torch.no_grad():
        reconstructed_img = unet_model(img)

    # Post-process the reconstructed image for display
    reconstructed_img_display = torch.clamp(reconstructed_img * 0.5 + 0.5, 0, 1)  # Denormalize and clamp to [0, 1]
    reconstructed_pil = transforms.ToPILImage()(reconstructed_img_display.cpu().squeeze(0))
    
    # Resize the reconstructed image to 128x128
    reconstructed_pil_resized = resize_image_to_128x128(reconstructed_pil)
    
    # Apply sharpening to the resized image
    reconstructed_pil_sharpened = sharpen_image(reconstructed_pil_resized)

    # Classify the reconstructed image using ResNet50
    with torch.no_grad():
        outputs = resnet_model(reconstructed_img)
        _, predicted_label = torch.max(outputs, 1)
    
    # Get the classified image using the predicted label
    classified_image = get_classified_image(predicted_label.item())

    # Resize the classified image to 128x128
    classified_image_resized = resize_image_to_128x128(classified_image)
    
    return reconstructed_pil_sharpened, classified_image_resized

# Gradio interface to show reconstructed image and classification result
interface = gr.Interface(
    fn=reconstruct_and_classify, 
    inputs=gr.Image(type="pil"),
    outputs=[gr.Image(type="pil", label="Reconstructed Image"), gr.Image(type="pil", label="Classified Image")]  # Outputs: Reconstructed and Classified Image
)

# Launch the interface within the notebook
interface.launch(inline=True)

  unet_model.load_state_dict(torch.load(unet_model_path, map_location=device))
  resnet_model.load_state_dict(torch.load('resnet50_model.pth', map_location=device))


* Running on local URL:  http://127.0.0.1:7860

To create a public link, set `share=True` in `launch()`.


