In [5]:
import torch
import torch.nn as nn
import numpy as np
import gradio as gr
from torchvision import transforms
from PIL import Image
import torch.nn.functional as F
import logging
import matplotlib.pyplot as plt

# Set up logging
logging.basicConfig(filename='app1.log', level=logging.DEBUG, format='%(asctime)s - %(levelname)s - %(message)s')

class VGG(nn.Module):
    """Small VGG-style CNN classifier for 32x32 images.

    The feature extractor has three conv stages. Each stage is two 3x3 conv +
    BatchNorm + ReLU layers, then a 2x2 max-pool and dropout. The stages halve
    the spatial size 32 -> 16 -> 8 -> 4, so the flattened feature vector has
    128 * 4 * 4 elements. The fully connected head therefore only fits 32x32
    inputs.

    Module names (``conv_layers`` / ``fc_layers``) and the flat
    ``nn.Sequential`` indices match the original hand-written layout, so
    existing ``state_dict`` checkpoints keep loading.

    Args:
        num_classes: Number of output logits. Defaults to 10.
        in_channels: Number of input image channels. Defaults to 3 (RGB).
    """

    def __init__(self, num_classes=10, in_channels=3):
        super().__init__()
        # Unpack each stage so the Sequential stays flat (indices 0..23),
        # exactly like the original explicit layer list.
        self.conv_layers = nn.Sequential(
            *self._conv_stage(in_channels, 32),
            *self._conv_stage(32, 64),
            *self._conv_stage(64, 128),
        )

        self.fc_layers = nn.Sequential(
            nn.Linear(128 * 4 * 4, 256),
            nn.BatchNorm1d(256),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(256, num_classes),
        )

    @staticmethod
    def _conv_stage(in_ch, out_ch):
        """Return the 8 layers of one conv stage; the stage halves H and W."""
        return [
            nn.Conv2d(in_ch, out_ch, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_ch),
            nn.ReLU(),
            nn.Conv2d(out_ch, out_ch, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_ch),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),
            nn.Dropout(0.25),
        ]

    def forward(self, x):
        """Map a (batch, in_channels, 32, 32) tensor to (batch, num_classes) logits."""
        x = self.conv_layers(x)
        # flatten works on non-contiguous tensors too; view() would raise.
        x = torch.flatten(x, 1)
        return self.fc_layers(x)

def load_model(checkpoint_path='./models/best_model.pth'):
    """Build a ``VGG`` model, load trained weights, and switch it to eval mode.

    Args:
        checkpoint_path: File written by ``torch.save``. It may contain a raw
            state dict, or a dict wrapping one under ``'state_dict'`` or
            ``'model_state_dict'`` (checked in that order).
            Defaults to ``./models/best_model.pth``.

    Returns:
        ``(model, device)``: the model on CUDA when available, else on CPU.

    Raises:
        Any exception from model construction or checkpoint loading. It is
        logged together with its traceback and then re-raised.
    """
    try:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        model = VGG().to(device)

        # SECURITY NOTE(review): torch.load unpickles the file, which can
        # execute arbitrary code. Only load trusted checkpoints. If this file
        # holds only tensors and plain values, pass weights_only=True, which
        # also silences torch's FutureWarning about the default.
        checkpoint = torch.load(checkpoint_path, map_location=device)

        # Unwrap common training-checkpoint layouts; otherwise assume the
        # file is the state dict itself.
        state_dict = checkpoint
        if isinstance(checkpoint, dict):
            for key in ('state_dict', 'model_state_dict'):
                if key in checkpoint:
                    state_dict = checkpoint[key]
                    break

        model.load_state_dict(state_dict)
        model.eval()
        return model, device
    except Exception as e:
        # logging.exception keeps the traceback that logging.error dropped.
        logging.exception(f"Error loading model: {e}")
        raise

def load_and_display_image(input_image):
    """Show the ``'composite'`` image of a sketch value with matplotlib.

    Args:
        input_image: A dict with a ``'composite'`` entry, such as the value of
            the ``gr.Sketchpad`` input below. The entry holds array-like image
            data: 2-D grayscale, (H, W, 1), or 3-D with 3 (RGB) or 4 (RGBA)
            channels. RGBA is accepted for consistency with
            ``preprocess_image``, which converts such images to RGB.

    Raises:
        ValueError: If ``input_image`` is not such a dict, or the image has
            any other shape.
    """
    if not (isinstance(input_image, dict) and 'composite' in input_image):
        raise ValueError("Input image is not in the expected format.")
    image = np.array(input_image['composite'])

    # imshow rejects an (H, W, 1) array; treat it as plain grayscale.
    if image.ndim == 3 and image.shape[2] == 1:
        image = image[:, :, 0]

    if image.ndim == 2:  # Grayscale image
        cmap = 'gray'
    elif image.ndim == 3 and image.shape[2] in (3, 4):  # RGB / RGBA image
        cmap = None  # colour data carries its own colours
    else:
        raise ValueError("Input image is not in a recognized format (grayscale or RGB).")

    plt.imshow(image, cmap=cmap)
    plt.title('Original Input Image')
    plt.axis('off')
    plt.show()

def preprocess_image(input_image):
    """Convert a sketch value into a normalized model input tensor.

    Steps:
        1. Take ``input_image['composite']`` as a numpy array.
        2. Squeeze an (H, W, 1) array, which ``Image.fromarray`` cannot handle.
        3. Replicate 2-D grayscale into 3 channels.
        4. Cast to uint8 and convert to an RGB PIL image.
        5. Resize to 32x32, apply ``ToTensor``, and normalize with mean=std=0.5
           per channel.
        6. Add a batch dimension.

    Args:
        input_image: A dict with a ``'composite'`` image entry.

    Returns:
        A float tensor of shape (1, 3, 32, 32).

    Raises:
        ValueError: If ``input_image`` is not a dict with ``'composite'``.
        Any error is logged with its traceback and re-raised.
    """
    try:
        # Check if the input image is in the expected format
        if not (isinstance(input_image, dict) and 'composite' in input_image):
            raise ValueError("Input image is not in the expected format.")
        image = np.array(input_image['composite'])
        logging.debug(f"Received image shape: {image.shape}")

        # Image.fromarray raises on an (H, W, 1) array; drop the singleton channel.
        if image.ndim == 3 and image.shape[2] == 1:
            image = image[:, :, 0]

        # Convert grayscale to RGB if necessary
        if image.ndim == 2:
            image = np.stack((image,) * 3, axis=-1)

        logging.debug(f"Image after conversion to RGB: {image.shape}")

        # NOTE(review): astype('uint8') assumes pixel values are already in
        # 0-255; float images in [0, 1] would be truncated to 0. TODO confirm.
        image = Image.fromarray(image.astype('uint8')).convert("RGB")
        logging.debug(f"PIL image size: {image.size}")

        transform = transforms.Compose([
            transforms.Resize((32, 32)),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
        ])

        tensor = transform(image)
        logging.debug(f"Transformed tensor shape: {tensor.shape}")

        return tensor.unsqueeze(0)  # Add batch dimension

    except Exception as e:
        # logging.exception keeps the traceback that logging.error dropped.
        logging.exception(f"Error in preprocessing: {e}")
        raise

def predict(input_image):
    """Classify a sketch with the module-level ``model``.

    Uses the globals ``model`` and ``device``, which are assigned from
    ``load_model()`` at module level.

    Args:
        input_image: A dict with a ``'composite'`` image entry (see
            ``preprocess_image``).

    Returns:
        A 2-tuple matching the interface's two outputs:
        ``(composite_image, {"0": p0, "1": p1, ...})``. Each ``p`` is a
        softmax probability rounded to 2 decimals. On any error, the tuple
        is ``(None, {"0": 0.0, ..., "9": 0.0})``. Previously only the dict
        was returned, which broke the two-output interface.
    """
    try:
        logging.debug("Predict function started.")
        processed_image = preprocess_image(input_image).to(device)

        with torch.no_grad():
            outputs = model(processed_image)
            probabilities = F.softmax(outputs, dim=1)[0]

        logging.debug(f"Raw outputs: {outputs}")
        logging.debug(f"Probabilities: {probabilities}")

        predictions = {
            str(i): round(p, 2)
            for i, p in enumerate(probabilities.tolist())
        }
        logging.debug(f"Predictions: {predictions}")

        # Return the original image and predictions
        return input_image['composite'], predictions

    except Exception as e:
        logging.exception(f"Error in prediction: {e}")
        # Keep the output arity the interface expects: (image, label dict).
        return None, {str(i): 0.0 for i in range(10)}

model, device = load_model()


def _build_interface():
    """Assemble the Gradio UI: sketch input -> (echoed image, top-10 label scores)."""
    # Drawing canvas: cropped to 32x32, delivered as numpy, single-channel "L" mode.
    sketch_input = gr.Sketchpad(crop_size=(32, 32), type='numpy', image_mode='L', brush=gr.Brush())
    # Two outputs, in the order predict() returns them.
    echoed_image = gr.Image(type="numpy", label="Input Image")
    digit_scores = gr.Label(num_top_classes=10)

    return gr.Interface(
        fn=predict,
        inputs=sketch_input,
        outputs=[echoed_image, digit_scores],
        title="Digit Recognition with VGG Network",
        description="Draw a digit (0-9) and the model will predict what digit it is.",
        article="The model will show confidence scores for all digits (0-9).",
        examples=[],
        cache_examples=False,
        theme=gr.themes.Default(),
    )


interface = _build_interface()

if __name__ == "__main__":
    interface.launch(share=False)

    
    

  checkpoint = torch.load('./models/best_model.pth', map_location=device)


* Running on local URL:  http://127.0.0.1:7862

To create a public link, set `share=True` in `launch()`.
