In [2]:
pip install torch torchvision datasets tqdm

Note: you may need to restart the kernel to use updated packages.


In [3]:
import torch
import torchvision
from torch.utils.data import DataLoader
from torchvision import transforms
from datasets import load_dataset
from torch import nn, optim
from torch.optim import lr_scheduler
from tqdm import tqdm
from torchvision.models import resnet50
from torch.utils.data import Dataset
from PIL import Image

# Fetch every split of the diabetic-retinopathy image dataset from the HuggingFace Hub
DATASET_ID = "majorSeaweed/Diabetic_retinopathy_images"
ds = load_dataset(DATASET_ID)
print(ds)  # summary of the available splits, their features and row counts

DatasetDict({
    train: Dataset({
        features: ['image', 'label'],
        num_rows: 115241
    })
    test: Dataset({
        features: ['image', 'label'],
        num_rows: 14201
    })
    validation: Dataset({
        features: ['image', 'label'],
        num_rows: 14227
    })
})


In [4]:
# Peek at one training record to confirm the schema before building the pipeline
first_sample = ds['train'][0]
print(first_sample)

{'image': <PIL.JpegImagePlugin.JpegImageFile image mode=RGB size=600x600 at 0x220956C07D0>, 'label': 0}


In [5]:
# Image preprocessing pipeline applied to each sample before it reaches the network.
# The mean/std values are the ImageNet channel statistics documented for
# torchvision's pre-trained classification models.
IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD = [0.229, 0.224, 0.225]

preprocessing_steps = [
    transforms.Resize((224, 224)),  # force every image to a fixed 224x224 size
    transforms.ToTensor(),  # PIL image -> tensor
    transforms.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),  # per-channel standardization
]
transform = transforms.Compose(preprocessing_steps)

In [6]:
class DiabeticRetinopathyDataset(Dataset):
    """Map-style dataset yielding ``(image, label)`` pairs from one split.

    Args:
        dataset: Mapping from split name to an indexable collection of
            records, where each record provides ``'image'`` and ``'label'``
            entries.
        transform: Optional callable applied to each image before it is
            returned. ``None`` returns the image unchanged.
        split: Name of the split to serve. Defaults to ``"train"``, which
            matches the behaviour of the two-argument constructor call.
    """

    def __init__(self, dataset, transform=None, split="train"):
        self.dataset = dataset
        self.transform = transform
        self.split = split

    def __len__(self):
        """Return the number of records in the selected split."""
        return len(self.dataset[self.split])

    def __getitem__(self, idx):
        """Return the ``(image, label)`` pair stored at position ``idx``."""
        # Index the row once and reuse it: looking the same row up separately
        # for the image and the label fetches the whole record twice.
        record = self.dataset[self.split][idx]
        img = record['image']
        label = record['label']

        # Apply the transformation only when one was supplied
        if self.transform is not None:
            img = self.transform(img)

        return img, label


In [7]:
# Wrap the training split and serve it in shuffled mini-batches of 32 samples
train_dataset = DiabeticRetinopathyDataset(dataset=ds, transform=transform)
train_loader = DataLoader(dataset=train_dataset, batch_size=32, shuffle=True)

In [8]:
# Initialize the ResNet50 model with ImageNet pre-trained weights.
# `pretrained=True` is deprecated since torchvision 0.13; "IMAGENET1K_V1" names the
# same weights that flag loaded, so the starting point is unchanged.
model = resnet50(weights="IMAGENET1K_V1")
# Replace the final fully connected layer with a fresh 5-output classifier.
# Adjust the output size to the number of classes in your dataset
# (e.g., 5 for 5 levels of retinopathy).
model.fc = nn.Linear(model.fc.in_features, 5)



In [9]:
# Prefer the GPU when CUDA is usable, otherwise fall back to the CPU
device_name = "cuda" if torch.cuda.is_available() else "cpu"
device = torch.device(device_name)
# Move the model to the chosen device; as the cell's last expression this
# also displays the model's layer summary.
model.to(device)

ResNet(
  (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace=True)
  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (layer1): Sequential(
    (0): Bottleneck(
      (conv1): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (downsample): Sequential(
        (0): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 

In [10]:
# Training objective (cross-entropy over class logits) and the Adam update rule
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(params=model.parameters(), lr=1e-3)

In [11]:
# Training the model: one full pass over train_loader per epoch, reporting the
# per-sample mean loss and the accuracy measured on the training data itself.
num_epochs = 10  # Adjust based on your training time
for epoch in range(num_epochs):
    model.train()
    running_loss = 0.0  # sum of per-sample losses seen so far in this epoch
    correct = 0
    total = 0
    for inputs, labels in tqdm(train_loader):
        inputs, labels = inputs.to(device), labels.to(device)
        batch_size = labels.size(0)

        # Zero the parameter gradients
        optimizer.zero_grad()

        # Forward pass
        outputs = model(inputs)
        loss = criterion(outputs, labels)

        # Backward pass
        loss.backward()
        optimizer.step()

        # Statistics. The criterion returns the mean over the batch, so scale it
        # back by the batch size; dividing by the sample count below then gives a
        # true per-sample average even when the final batch is shorter.
        running_loss += loss.item() * batch_size
        _, predicted = torch.max(outputs, 1)
        total += batch_size
        correct += (predicted == labels).sum().item()

    epoch_loss = running_loss / total
    epoch_accuracy = correct / total * 100
    print(f"Epoch {epoch+1}/{num_epochs}, Loss: {epoch_loss:.4f}, Accuracy: {epoch_accuracy:.2f}%")


100%|████████████████████████████████████████████████████████████████████████████| 3602/3602 [8:33:07<00:00,  8.55s/it]


Epoch 1/10, Loss: 0.9200, Accuracy: 64.59%


100%|███████████████████████████████████████████████████████████████████████████| 3602/3602 [16:12:56<00:00, 16.21s/it]


Epoch 2/10, Loss: 0.7917, Accuracy: 70.20%


100%|███████████████████████████████████████████████████████████████████████████| 3602/3602 [16:57:11<00:00, 16.94s/it]


Epoch 3/10, Loss: 0.7312, Accuracy: 72.69%


100%|███████████████████████████████████████████████████████████████████████████| 3602/3602 [22:26:46<00:00, 22.43s/it]


Epoch 4/10, Loss: 0.6855, Accuracy: 74.35%


100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 3602/3602 [6:49:54<00:00,  6.83s/it]


Epoch 5/10, Loss: 0.6416, Accuracy: 76.12%


100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 3602/3602 [7:50:24<00:00,  7.84s/it]


Epoch 6/10, Loss: 0.5944, Accuracy: 77.93%


100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████| 3602/3602 [12:35:37<00:00, 12.59s/it]


Epoch 7/10, Loss: 0.5377, Accuracy: 79.99%


100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 3602/3602 [7:01:21<00:00,  7.02s/it]


Epoch 8/10, Loss: 0.4785, Accuracy: 82.31%


100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████| 3602/3602 [18:30:12<00:00, 18.49s/it]


Epoch 9/10, Loss: 0.4165, Accuracy: 84.72%


100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 3602/3602 [7:44:48<00:00,  7.74s/it]

Epoch 10/10, Loss: 0.3611, Accuracy: 86.70%





In [12]:
# Persist only the learned parameters (the state dict), not the whole model object
checkpoint_path = "diabetic_retinopathy_resnet50.pth"
trained_weights = model.state_dict()
torch.save(trained_weights, checkpoint_path)