<a href="https://colab.research.google.com/github/PallaviUpreti/PallaviUpreti/blob/main/diabetes.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import numpy as np
import torch
from torch import nn
from torchvision import models, transforms
from PIL import Image
import matplotlib.pyplot as plt
import os
import io
import ipywidgets as widgets
from IPython.display import display
from google.colab import files


# Check if CUDA (GPU) is available
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Build a ResNet-152 architecture WITHOUT downloading pre-trained weights
# (pretrained=False); the trained weights are restored later from a checkpoint
# by load_model(). The stock fully connected layer is replaced by a small
# classification head that outputs per-class log-probabilities.
model = models.resnet152(pretrained=False)
num_features = model.fc.in_features
num_classes = 5  # Replace with the number of classes in your task
model.fc = nn.Sequential(
    nn.Linear(num_features, 512),
    nn.ReLU(),
    nn.Linear(512, num_classes),
    nn.LogSoftmax(dim=1),  # log-probabilities, as expected by NLLLoss below
)
model.to(device)

# NLLLoss consumes the log-probabilities produced by the LogSoftmax head.
criterion = nn.NLLLoss()

# Fine-tune only 'layer2', 'layer3', 'layer4' and 'fc'; every other top-level
# child module (conv1, bn1, layer1, ...) is frozen.
trainable_layers = {'layer2', 'layer3', 'layer4', 'fc'}
for name, child in model.named_children():
    child.requires_grad_(name in trainable_layers)

# Create the optimizer only AFTER freezing so it tracks exactly the trainable
# parameters; StepLR decays the learning rate by 10x every 5 scheduler steps.
optimizer = torch.optim.Adam(
    (p for p in model.parameters() if p.requires_grad), lr=0.000001
)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=5, gamma=0.1)

# Function to load a saved model
def load_model(path):
    """Restore the module-level ``model`` and ``optimizer`` from a checkpoint.

    The checkpoint is read onto the CPU first; ``load_state_dict`` then copies
    the tensors into the already-constructed model/optimizer.

    Args:
        path: Location of a checkpoint dict containing the keys
            ``'model_state_dict'`` and ``'optimizer_state_dict'``.

    Returns:
        The global ``model`` instance, now holding the checkpoint weights.

    Note:
        ``torch.load`` unpickles data; only load checkpoints you trust.
    """
    state = torch.load(path, map_location='cpu')
    # Weights first, then optimizer state (same order as before, so a missing
    # optimizer entry still leaves the model weights loaded).
    model.load_state_dict(state['model_state_dict'])
    optimizer.load_state_dict(state['optimizer_state_dict'])
    restored = model
    return restored

# Function to perform image inference using the loaded model
def inference(model, image_file, transform, classes):
    """Classify one image, plot it with the prediction, and return the result.

    Args:
        model: Classifier whose forward pass returns per-class scores for a
            batch (log-probabilities for the LogSoftmax head built in this
            notebook). It is switched to eval mode here.
        image_file: A filename or binary file-like object accepted by
            ``PIL.Image.open`` (e.g. an ``io.BytesIO``).
        transform: Callable turning a PIL image into an image tensor
            (e.g. ``test_transforms``).
        classes: Sequence of labels indexed by predicted class id.

    Returns:
        ``(predicted_class, classes[predicted_class])``: the integer class
        index and its label.
    """
    # Decode eagerly and release the file handle; convert('RGB') returns an
    # independent image, so the source can be closed immediately.
    with Image.open(image_file) as src:
        image = src.convert('RGB')
    batch = transform(image).unsqueeze(0)  # add a leading batch dimension

    # Feed the input to whatever device the model's weights actually live on,
    # instead of guessing from CUDA availability (a CPU-resident model on a
    # GPU machine would otherwise receive a CUDA tensor and fail).
    first_param = next(model.parameters(), None)
    device = first_param.device if first_param is not None else torch.device("cpu")
    model.eval()

    with torch.no_grad():
        scores = model(batch.to(device))
        # exp() is monotonic, so the arg-max of the log-probabilities is the
        # arg-max of the probabilities.
        predicted_class = scores.argmax(dim=1).item()

    label = classes[predicted_class]
    print("Predicted Severity Value:", predicted_class)
    print("Class:", label)
    print('Displaying the image:')
    plt.imshow(np.array(image))
    plt.text(15, 40, "Presence of Diabetes: " + label, fontsize=12,
             bbox=dict(facecolor='red', alpha=0.5))
    plt.text(15, 100, "Prediction Severity Value: " + str(predicted_class), fontsize=12,
             bbox=dict(facecolor='red', alpha=0.5))
    plt.show()

    return predicted_class, label

# Load the trained model (replace with your model path).
# NOTE: Google Drive must already be mounted at /content/drive (see the
# drive.mount cell) or this path will not exist.
model_path = '/content/drive/MyDrive/project/diabetes_detection_ml/model_weights/classifier.pt'
model = load_model(model_path)
print("Model loaded successfully")

# Class labels; index i is the label for model output class i.
classes = ['No DR', 'Mild', 'Moderate', 'Severe', 'Proliferative DR']

# Deterministic evaluation preprocessing: resize to 224x224, convert to a
# tensor, and normalize with the ImageNet channel mean/std. No random
# augmentation here: a random flip at test time makes the same image produce
# different inputs (and potentially different predictions) on each run.
test_transforms = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
])

# Function to handle form submission and perform inference
def perform_inference(b):
    """Button callback: prompt for image upload(s) and classify each image.

    Reads the patient name and age from the form widgets, asks for a file via
    Colab's ``files.upload()``, runs ``inference`` on every uploaded file, and
    shows all output (upload prompt, prints, plot, summary) in the
    ``result_output`` area.

    Args:
        b: The ``widgets.Button`` that was clicked (required by
            ``Button.on_click``; unused).
    """
    name = name_input.value
    age = age_input.value

    # Output produced inside a widget callback is not attached to the cell
    # unless it is captured by an Output widget, so render everything inside
    # result_output. (Assigning result_output.value displays nothing: Output
    # widgets have no display-backed 'value' attribute.)
    with result_output:
        uploaded = files.upload()
        if not uploaded:
            print("Please upload an image.")
            return

        for data in uploaded.values():
            image_path = io.BytesIO(data)

            print(f"Name: {name}")
            print(f"Age: {age}")

            _, class_name = inference(model, image_path, test_transforms, classes)

            print(f"Patient Name: {name}\nAge: {age}\nPredicted Class: {class_name}")



# --- Patient form inputs -----------------------------------------------------
name_input = widgets.Text(description='Name:',
                          placeholder='Enter patients full name',
                          value='')
age_input = widgets.IntText(description='Age:', value=0)

# Standalone file-upload widget.
# NOTE(review): it is never placed in the form below; uploads actually happen
# through files.upload() inside perform_inference.
upload_button = widgets.FileUpload(description='Upload Image',
                                   multiple=False,
                                   accept='image/*')

# Clicking this button triggers perform_inference (which prompts for a file).
submit_button = widgets.Button(description='upload image')
submit_button.on_click(perform_inference)

# --- Layout and display --------------------------------------------------------
form_children = [name_input, age_input, submit_button]
form_container = widgets.VBox(form_children)

# Area where inference results are rendered.
result_output = widgets.Output()

display(form_container, result_output)


Collecting streamlit
  Downloading streamlit-1.27.2-py2.py3-none-any.whl (7.6 MB)
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 7.6/7.6 MB 14.1 MB/s eta 0:00:00
Collecting validators<1,>=0.2 (from streamlit)
  Downloading validators-0.22.0-py3-none-any.whl (26 kB)
Collecting gitpython!=3.1.19,<4,>=3.0.7 (from streamlit)
  Downloading GitPython-3.1.37-py3-none-any.whl (190 kB)
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 190.0/190.0 kB 15.2 MB/s eta 0:00:00
Collecting pydeck<1,>=0.8.0b4 (from streamlit)
  Downloading pydeck-0.8.1b0-py2.py3-none-any.whl (4.8 MB)
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 4.8/4.8 MB 32.7 MB/s eta 0:00:00
Collecting watchdog>=2.1.5 (from streamlit)
  Downloading watchdog-3.0.0-py3-none-manylinux2014_x86_64.whl (82 kB)
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 82.1/82.1 kB 7.9 MB/s eta 0:00:0

VBox(children=(Text(value='', description='Name:', placeholder='Enter patients full name'), IntText(value=0, d…

Output()

In [None]:
# Mount Google Drive so files under /content/drive (such as the model
# checkpoint) become readable.
# NOTE(review): this must run before the cell that calls load_model(),
# otherwise the checkpoint path does not exist yet.
from google.colab import drive
drive.mount(mountpoint='/content/drive')

Mounted at /content/drive
