# Lab3 - Diabetic Retinopathy Detection

In [1]:
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
from tqdm import tqdm
import copy
import matplotlib.pyplot as plt

from dataloader import RetinopathyLoader

# Compute device shared by the notebook: the GPU when CUDA is usable, else the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

In [2]:
class BasicBlock(nn.Module):
    """Two-convolution residual block (3x3 conv -> BN -> ReLU -> 3x3 conv -> BN).

    The block output is ``relu(block(x) + shortcut(x))``. The shortcut is an
    identity mapping unless the spatial stride or the channel count changes,
    in which case a strided 1x1 convolution followed by batch norm is used.
    """

    # Ratio between the block's output channels and ``planes``.
    expansion = 1

    def __init__(self, in_planes, planes, stride=1):
        super().__init__()

        out_planes = self.expansion * planes

        # Main path; only the first convolution carries the stride.
        main_path = [
            nn.Conv2d(in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False),
            nn.BatchNorm2d(planes),
            nn.ReLU(),
            nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False),
            nn.BatchNorm2d(planes),
        ]
        self.block = nn.Sequential(*main_path)

        # Skip path: an empty Sequential acts as the identity.
        shape_preserved = stride == 1 and in_planes == out_planes
        if shape_preserved:
            self.shortcut = nn.Sequential()
        else:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(out_planes),
            )

    def forward(self, x):
        main = self.block(x)
        skip = self.shortcut(x)
        return F.relu(main + skip)

In [4]:
class Bottleneck(nn.Module):
    """Three-convolution residual block (1x1 reduce -> 3x3 -> 1x1 expand).

    The final 1x1 convolution widens the feature map to
    ``expansion * planes`` channels. The block output is
    ``relu(block(x) + shortcut(x))``; the shortcut is an identity mapping
    unless the stride or the channel count changes, in which case a strided
    1x1 convolution followed by batch norm is used.
    """

    # Ratio between the block's output channels and ``planes``.
    expansion = 4

    def __init__(self, in_planes, planes, stride=1):
        super().__init__()

        out_planes = self.expansion * planes

        # Main path; the stride is applied by the middle 3x3 convolution.
        main_path = [
            nn.Conv2d(in_planes, planes, kernel_size=1, bias=False),
            nn.BatchNorm2d(planes),
            nn.ReLU(),
            nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=1, bias=False),
            nn.BatchNorm2d(planes),
            nn.ReLU(),
            nn.Conv2d(planes, out_planes, kernel_size=1, bias=False),
            nn.BatchNorm2d(out_planes),
        ]
        self.block = nn.Sequential(*main_path)

        # Skip path: an empty Sequential acts as the identity.
        shape_preserved = stride == 1 and in_planes == out_planes
        if shape_preserved:
            self.shortcut = nn.Sequential()
        else:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(out_planes),
            )

    def forward(self, x):
        main = self.block(x)
        skip = self.shortcut(x)
        return F.relu(main + skip)

In [6]:
class ResNet(nn.Module):
    """ResNet-style classifier: conv stem, four residual stages, pooled linear head.

    Args:
        block: Residual block class. It must expose an integer ``expansion``
            class attribute and be constructible as
            ``block(in_planes, planes, stride)``.
        layers: Number of blocks in each of the four stages. Only read when
            the network is built from scratch.
        num_classes: Size of the final linear layer's output.
        pretrained_model: Optional module whose ``conv1``, ``bn1`` and
            ``layer1`` .. ``layer4`` children are reused as the backbone.
            The modules are shared with the donor, not copied. When ``None``
            a fresh backbone is built from ``block`` and ``layers``.

    Raises:
        KeyError: If ``pretrained_model`` lacks one of the expected children.
    """

    def __init__(self, block, layers, num_classes=5, pretrained_model=None):
        super().__init__()

        # Running input-channel count consumed and updated by _make_layer.
        self.in_planes = 64

        if pretrained_model is None:
            self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False)
            self.bn1 = nn.BatchNorm2d(64)

            self.layer1 = self._make_layer(block, 64, layers[0], stride=1)
            self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
            self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
            self.layer4 = self._make_layer(block, 512, layers[3], stride=2)
        else:
            # Reuse the donor's backbone stages through the public child
            # listing rather than the private ``_modules`` dict.
            # NOTE(review): torchvision's own ResNet forward also applies a
            # max-pool after the stem; it is not reused here, so feature maps
            # entering layer1 are larger than in the donor -- confirm this is
            # intentional.
            donor_children = dict(pretrained_model.named_children())

            self.conv1 = donor_children['conv1']
            self.bn1 = donor_children['bn1']

            self.layer1 = donor_children['layer1']
            self.layer2 = donor_children['layer2']
            self.layer3 = donor_children['layer3']
            self.layer4 = donor_children['layer4']

        # The head assumes the last stage emits 512 * block.expansion channels.
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.linear = nn.Linear(512*block.expansion, num_classes)

    def _make_layer(self, block, planes, num_blocks, stride):
        """Build one stage of ``num_blocks`` blocks.

        Only the first block uses ``stride``; the remaining blocks use
        stride 1. ``self.in_planes`` is advanced to the stage's output width.
        """
        block_strides = [stride] + [1]*(num_blocks-1)
        stage = []
        for block_stride in block_strides:
            stage.append(block(self.in_planes, planes, block_stride))
            self.in_planes = planes * block.expansion
        return nn.Sequential(*stage)

    def forward(self, x):
        """Return the ``(batch, num_classes)`` logits for an image batch ``x``."""
        out = F.relu(self.bn1(self.conv1(x)))
        out = self.layer1(out)
        out = self.layer2(out)
        out = self.layer3(out)
        out = self.layer4(out)
        out = self.avg_pool(out)
        # Collapse the pooled (C, 1, 1) map into a feature vector per sample.
        out = out.view(out.size(0), -1)
        out = self.linear(out)

        return out
    
    
def ResNet18(pretrain=False, num_classes=5):
    """Build a ResNet-18 classifier (BasicBlock, stage depths 2-2-2-2).

    When ``pretrain`` is truthy, the backbone stages are taken from
    torchvision's pretrained ``resnet18``; otherwise the network is built
    from scratch. ``num_classes`` sets the output size of the linear head.
    """
    backbone = torchvision.models.resnet18(pretrained=True) if pretrain else None
    return ResNet(BasicBlock, [2, 2, 2, 2], num_classes, pretrained_model=backbone)


def ResNet50(pretrain=False, num_classes=5):
    """Build a ResNet-50 classifier (Bottleneck, stage depths 3-4-6-3).

    When ``pretrain`` is truthy, the backbone stages are taken from
    torchvision's pretrained ``resnet50``; otherwise the network is built
    from scratch. ``num_classes`` sets the output size of the linear head.
    """
    backbone = torchvision.models.resnet50(pretrained=True) if pretrain else None
    return ResNet(Bottleneck, [3, 4, 6, 3], num_classes, pretrained_model=backbone)

In [7]:
def train_model(model, train_dataset, test_dataset,
                batch_size=4, learning_rate=1e-3, num_epochs=10, optim=torch.optim.SGD,
                momentum=0.9, weight_decay=5e-4, loss_function=torch.nn.CrossEntropyLoss):
    """Train ``model`` and return it loaded with its best-scoring weights.

    Each epoch runs one pass over ``train_dataset`` followed by an accuracy
    evaluation on ``test_dataset``. The weights of the epoch with the highest
    test accuracy are restored into ``model`` before returning.

    Batches are moved to the module-level ``device``; the model itself is not
    moved here, so the caller is expected to have placed it on that device.

    Args:
        model: Network to optimise; called as ``model(x)`` and expected to
            return per-class scores along dimension 1.
        train_dataset: Dataset yielding ``(input, label)`` pairs for training.
        test_dataset: Dataset yielding ``(input, label)`` pairs for evaluation.
        batch_size: Batch size for both data loaders.
        learning_rate: Learning rate passed to the optimizer as ``lr``.
        num_epochs: Number of train/evaluate rounds.
        optim: Optimizer class (or factory) called with the model parameters.
        momentum: Forwarded to the optimizer when its signature accepts a
            ``momentum`` argument.
        weight_decay: Forwarded to the optimizer when its signature accepts a
            ``weight_decay`` argument.
        loss_function: Zero-argument factory returning the training criterion.

    Returns:
        The same ``model`` object. If no epoch ran (``num_epochs`` is 0) the
        weights are left untouched.
    """
    # Local import: only needed to probe the optimizer's signature.
    import inspect

    # Data loaders
    # NOTE(review): the training loader is created without shuffle=True, so
    # batches are visited in dataset order every epoch -- confirm intended.
    train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size)
    test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=batch_size)

    # Define loss function and optimizer
    criterion = loss_function()

    # Forward momentum / weight_decay only to optimizers that declare them
    # (e.g. SGD takes both, Adam has no ``momentum`` argument).
    try:
        accepted = inspect.signature(optim).parameters
    except (TypeError, ValueError):
        accepted = {}
    optim_kwargs = {'lr': learning_rate}
    for name, value in (('momentum', momentum), ('weight_decay', weight_decay)):
        if name in accepted:
            optim_kwargs[name] = value
    optimizer = optim(model.parameters(), **optim_kwargs)

    # Train & Test
    best_acc = 0.0
    best_wts = None
    for epoch in range(num_epochs):
        # Train
        model.train()

        correct, total = 0, 0
        for x, y in tqdm(train_loader, total=len(train_loader)):

            x = x.to(device)
            y = y.to(device)

            # Forward
            outputs = model(x)
            loss = criterion(outputs, y)
            preds = outputs.detach().argmax(dim=1)
            correct += (preds == y).sum().item()
            total += y.size(0)

            # Back propagation & optimize
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        # An empty dataset yields 0.0 instead of a ZeroDivisionError.
        train_acc = correct / total if total else 0.0

        # Test
        model.eval()

        correct, total = 0, 0
        with torch.no_grad():
            for x, y in tqdm(test_loader, total=len(test_loader)):

                x = x.to(device)
                y = y.to(device)

                # Forward
                outputs = model(x)
                preds = outputs.argmax(dim=1)
                correct += (preds == y).sum().item()
                total += y.size(0)
        test_acc = correct / total if total else 0.0

        # Save the best model. The first epoch is always snapshotted so a run
        # whose test accuracy never rises above 0.0 still has weights to restore.
        if best_wts is None or test_acc > best_acc:
            best_acc = test_acc
            best_wts = copy.deepcopy(model.state_dict())

        # Print statistics
        print('Epoch {}/{}'.format(epoch+1, num_epochs))
        print('-------------------')
        print('Training acc: {}'.format(train_acc))
        print('Testing acc: {}'.format(test_acc))
        print()

        torch.cuda.empty_cache()

    # Nothing to restore when no epoch ran.
    if best_wts is not None:
        model.load_state_dict(best_wts)
    return model