In [None]:
# Mount Google Drive at /content/drive (Colab-only helper). The checkpoint
# path used later in this notebook lives under this mount point.
from google.colab import drive
drive.mount('/content/drive')

In [None]:
!pip install efficientnet_pytorch


In [None]:
import torch
import torch.nn as nn
from efficientnet_pytorch import EfficientNet
import torch.optim as optim

class EfficientNetB0(nn.Module):
    """EfficientNet-B0 whose final ``_fc`` layer is replaced by a new linear head.

    Args:
        num_classes: Number of outputs of the replacement ``nn.Linear`` head.
        pretrained: When True (the default, and the original behaviour) the
            backbone is built with ``EfficientNet.from_pretrained``. When
            False it is built with ``EfficientNet.from_name``, which only
            constructs the architecture; use this when a full checkpoint is
            loaded via ``load_state_dict`` right afterwards, so no pretrained
            weights need to be fetched first.
    """

    def __init__(self, num_classes=4, pretrained=True):
        super().__init__()
        if pretrained:
            self.model = EfficientNet.from_pretrained('efficientnet-b0')
        else:
            self.model = EfficientNet.from_name('efficientnet-b0')
        # Swap the stock classifier for one sized to this task, keeping the
        # input width of the layer it replaces.
        in_features = self.model._fc.in_features
        self.model._fc = nn.Linear(in_features, num_classes)

    def forward(self, x):
        """Run ``x`` through the wrapped EfficientNet and return its output."""
        return self.model(x)

# Build the 4-class network and restore the fine-tuned weights from Drive.
model = EfficientNetB0(num_classes=4)
model_path = '/content/drive/My Drive/models/RSCNN/Eff_Net.pth'
# NOTE(security): torch.load unpickles the file, so only load checkpoints that
# come from a trusted source. The tensors are materialised on the CPU first
# and the whole model is moved to the target device afterwards.
model.load_state_dict(torch.load(model_path, map_location=torch.device("cpu")))

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Switch to eval mode once here so any inference run after this cell does not
# execute the network in training mode. Assigning the result also keeps the
# cell from echoing the full module repr as its output.
model = model.to(device).eval()


In [None]:
import torch
from torchvision import transforms
from PIL import Image
import ipywidgets as widgets
from IPython.display import display
import io

# Preprocessing applied to every uploaded image before it reaches the model:
# resize to 224x224, convert to a tensor, then normalise each colour channel.
# NOTE(review): the mean/std values look like the usual ImageNet statistics —
# confirm they match the normalisation used when the checkpoint was trained.
transform = transforms.Compose(
    [
        transforms.Resize(size=(224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ]
)

# Lookup used to turn the predicted index into a readable label.
# NOTE(review): despite its name this maps index -> class name; the order is
# presumably the label order used during training — confirm.
class_to_idx = dict(enumerate([
    'Normal',
    'Diabetic Retinopathy',
    'Macular Degeneration',
    'Drusen',
]))
def on_upload_change(change):
    """Classify the image delivered by the upload widget and print the result.

    Intended to be registered as an observer of the ``FileUpload`` widget's
    ``value`` trait.

    Args:
        change: The change notification passed by ``observe``;
            ``change['new']`` is the widget's updated ``value``.

    Side effects:
        Prints the predicted class index and name, or a short message when
        the uploaded bytes cannot be decoded as an image.
    """
    uploads = change['new']
    # ipywidgets 7.x reports uploads as a dict keyed by file name, whereas
    # 8.x reports a tuple of per-file dicts; normalise both to a list.
    if isinstance(uploads, dict):
        uploads = list(uploads.values())
    if not uploads:
        # An empty value (e.g. the widget was reset) leaves nothing to classify.
        return
    image_file = uploads[0]

    try:
        image = Image.open(io.BytesIO(image_file['content'])).convert("RGB")
    except OSError as exc:
        # The widget's accept filter does not guarantee a decodable image, so
        # report the problem instead of failing inside the callback.
        print(f'Could not read the uploaded file as an image: {exc}')
        return

    # Add a leading batch dimension and move the tensor to the model's device.
    input_image = transform(image).unsqueeze(0).to(device)

    model.eval()
    with torch.no_grad():
        outputs = model(input_image)

    class_index = outputs.argmax(dim=1).item()
    predicted_class_name = class_to_idx.get(class_index, "Unknown")

    print(f'Predicted Class Index: {class_index}')
    print(f'Predicted Class Name: {predicted_class_name}')

# Upload control: a single-file picker whose file dialog is filtered to images.
uploaded_image = widgets.FileUpload(accept='image/*', multiple=False)

# Run on_upload_change whenever the widget's `value` trait changes, then
# render the widget in this cell's output.
uploaded_image.observe(on_upload_change, names='value')
display(uploaded_image)
