<a href="https://colab.research.google.com/github/110805/Retinopathy_detection/blob/master/ResNet.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [2]:
# Fetch the project repository and make it the working directory.
# The next cell imports `dataloader` (presumably dataloader.py from this repo)
# and reads images from /content/Retinopathy_detection/data/.
!git clone https://github.com/110805/Retinopathy_detection.git
%cd Retinopathy_detection/

Cloning into 'Retinopathy_detection'...
remote: Enumerating objects: 18, done.
remote: Counting objects: 100% (18/18), done.
remote: Compressing objects: 100% (15/15), done.
remote: Total 35153 (delta 9), reused 6 (delta 3), pack-reused 35135

In [12]:
from dataloader import RetinopathyLoader
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader
import torchvision.models
import matplotlib.pyplot as plt
import copy

# Hyperparameter setting
batch_size = 4
learning_rate = 1e-3
epochs_18 = 10   # epochs for each ResNet18 variant
epochs_50 = 5    # epochs for each ResNet50 variant
momentum = 0.9
weight_decay = 5e-4
num_classes = 5  # output width of the replacement `fc` head on every model
seed = 0
criterion = nn.CrossEntropyLoss()

# Seed torch/numpy so the random init of the non-pretrained models and the
# shuffled batch order are reproducible across kernel restarts.
torch.manual_seed(seed)
np.random.seed(seed)

data_root = '/content/Retinopathy_detection/data/'
train_data = RetinopathyLoader(root=data_root, mode='train')
test_data = RetinopathyLoader(root=data_root, mode='test')
# Reshuffle the training set every epoch; evaluation order does not affect accuracy.
train_loader = DataLoader(train_data, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(test_data, batch_size=batch_size)


def build_resnet(arch, pretrained):
    """Build a torchvision ResNet and replace its final `fc` layer.

    Args:
        arch: torchvision constructor, e.g. `torchvision.models.resnet18`.
        pretrained: forwarded to the constructor's `pretrained` argument.

    Returns:
        The model with `fc` replaced by a fresh `nn.Linear` whose input width
        is read from the backbone's original `fc` (512 for ResNet18, 2048 for
        ResNet50) and whose output width is `num_classes`.
    """
    net = arch(pretrained=pretrained)
    net.fc = nn.Linear(net.fc.in_features, num_classes)
    return net


# Same construction order as before: 18 w/o, 18 with, 50 w/o, 50 with pretraining.
ResNet18_1 = build_resnet(torchvision.models.resnet18, pretrained=False)
ResNet18_2 = build_resnet(torchvision.models.resnet18, pretrained=True)
ResNet50_1 = build_resnet(torchvision.models.resnet50, pretrained=False)
ResNet50_2 = build_resnet(torchvision.models.resnet50, pretrained=True)

models_18 = [ResNet18_1, ResNet18_2]
models_50 = [ResNet50_1, ResNet50_2]
# Fall back to CPU so the notebook still runs (slowly) without a GPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

def train(model):
    """Train `model` for one epoch and return its training accuracy in percent.

    Reads the notebook-level `train_loader`, `optimizer`, `criterion` and
    `device`. `optimizer` must already be built over `model.parameters()`
    (the experiment loops create one per model before calling this).
    Accuracy is computed from the predictions made during the epoch, i.e.
    while the weights are still being updated.

    Args:
        model: network to train; switched to train mode. Its outputs are
            treated as per-class scores (argmax over dim 1).

    Returns:
        float: 100 * correct / samples seen, or 0.0 if the loader is empty.
    """
    model.train()
    correct = 0
    seen = 0

    for idx, (inputs, labels) in enumerate(train_loader):
        inputs = inputs.to(device).float()
        # Flatten so labels are 1-D for both the loss and the comparison.
        labels = labels.to(device).flatten()

        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        preds = outputs.argmax(dim=1)
        # Accumulate on-device; converting once at the end avoids a sync per batch.
        correct += (preds == labels).sum()
        seen += labels.size(0)
        if idx % 1000 == 0:
            print(idx)

    # Divide by the samples actually seen, not a hard-coded dataset size;
    # int() handles both the initial 0 and the accumulated tensor.
    epoch_acc = 100 * int(correct) / seen if seen else 0.0
    print('Train Acc: {:4f}'.format(epoch_acc))

    return epoch_acc

def test(model, best_acc, best_model_weight):
    # i indicates that which model we are running now
    model.eval()
    correct = 0

    for idx, (inputs, labels) in enumerate(test_loader):
        inputs = inputs.to(device)
        labels = labels.to(device)
        with torch.no_grad():
            outputs = model(inputs.float())

        _, preds = torch.max(outputs, 1) # the second return of max is the return of argmax
        correct += torch.sum(preds == labels.data.flatten())
        if idx%1000 == 0:
            print(idx)
            
    epoch_acc = 100*correct.item() / 7025    
    print('Test Acc: {:4f}'.format(epoch_acc))

    if epoch_acc > best_acc:
        best_model_weight = copy.deepcopy(model.state_dict())

    return epoch_acc, best_model_weight

# Best-test-accuracy state_dict for every trained model, in training order.
model_weight = []
# models_18 is [w/o pretraining, with pretraining]; two curves per model.
legend = ['Train(w/o pretrain)', 'Test(w/o pretrain)',
          'Train(with pretrain)', 'Test(with pretrain)']
for i, model in enumerate(models_18):
    best_acc = 0
    model.to(device)
    train_acc = []
    test_acc = []
    # Fresh optimizer per model; `train` reads this module-level name.
    optimizer = optim.SGD(model.parameters(), lr=learning_rate, momentum=momentum, weight_decay=weight_decay)
    weight = copy.deepcopy(model.state_dict())

    for epoch in range(epochs_18):
        print('Epoch {}'.format(epoch+1))
        train_acc.append(train(model))
        acc, weight = test(model, best_acc, weight)
        test_acc.append(acc)
        best_acc = max(best_acc, acc)

        print('-' * 10)

    model_weight.append(weight)
    print('Best Acc: {:4f}'.format(best_acc))
    # 1-based x axis to match the printed "Epoch N" numbering.
    epoch_axis = range(1, epochs_18 + 1)
    plt.plot(epoch_axis, train_acc, label=legend[2*i])
    plt.plot(epoch_axis, test_acc, label=legend[2*i+1])

plt.xlabel('Epochs')
plt.ylabel('Accuracy(%)')
plt.title("Result comparison(ResNet18)")
plt.legend(loc='best')
plt.savefig("Result_ResNet18.png")
plt.show()

# models_50 is [w/o pretraining, with pretraining]; two curves per model.
legend = ['Train(w/o pretrain)', 'Test(w/o pretrain)',
          'Train(with pretrain)', 'Test(with pretrain)']
for i, model in enumerate(models_50):
    best_acc = 0
    model.to(device)
    train_acc = []
    test_acc = []
    # Fresh optimizer per model; `train` reads this module-level name.
    optimizer = optim.SGD(model.parameters(), lr=learning_rate, momentum=momentum, weight_decay=weight_decay)
    weight = copy.deepcopy(model.state_dict())

    for epoch in range(epochs_50):
        print('Epoch {}'.format(epoch+1))
        train_acc.append(train(model))
        acc, weight = test(model, best_acc, weight)
        test_acc.append(acc)
        best_acc = max(best_acc, acc)

        print('-' * 10)

    # Appended after the ResNet18 entries collected earlier.
    model_weight.append(weight)
    print('Best Acc: {:4f}'.format(best_acc))
    # 1-based x axis to match the printed "Epoch N" numbering.
    epoch_axis = range(1, epochs_50 + 1)
    plt.plot(epoch_axis, train_acc, label=legend[2*i])
    plt.plot(epoch_axis, test_acc, label=legend[2*i+1])

plt.xlabel('Epochs')
plt.ylabel('Accuracy(%)')
plt.title("Result comparison(ResNet50)")
plt.legend(loc='best')
plt.savefig("Result_ResNet50.png")
plt.show()

> Found 28099 images...
> Found 7025 images...
Epoch 1
0


KeyboardInterrupt: ignored