In [1]:
import torchvision.models as models
import torch
import torch.nn.functional as F
from torchvision import transforms
from PIL import Image
import numpy as np
from torch import nn
from dataset import IDRiD_Dataset
from torch.utils.data import DataLoader
from torch.optim.lr_scheduler import ReduceLROnPlateau

In [2]:
class MTL(nn.Module):
    def __init__(self):
        super(MTL, self).__init__()
        resnet50 = models.resnet50(pretrained=True)
        self.features = torch.nn.Sequential(*(list(resnet50.children())[:-1]))
        self.last = nn.Sequential(nn.Linear(2048, 1024),nn.ReLU())
        self.retinopathy_classifier = nn.Sequential(nn.Linear(1024, 512),nn.ReLU(), nn.Linear(512, 5), nn.Softmax(dim=1))
        self.macular_edema_classifier = nn.Sequential(nn.Linear(1024, 512),nn.ReLU(), nn.Linear(512, 5), nn.Softmax(dim=1))
        self.fovea_center_cords = nn.Sequential(nn.Linear(1024, 512),nn.ReLU(), nn.Linear(512, 2))
        self.optical_disk_cords = nn.Sequential(nn.Linear(1024, 512),nn.ReLU(), nn.Linear(512, 2))

        
    def forward(self,data):
        out = self.features.forward(data).squeeze()
        out = self.last.forward(out)
        return (self.retinopathy_classifier(out),
                self.macular_edema_classifier(out),
                self.fovea_center_cords(out),
                self.optical_disk_cords(out))


In [4]:
class MTL_trainer():
    def __init__(self,mtl,optimizer, scheduler, criterion, tasks, epochs):
        self.mtl = mtl
        self.optimizer = optimizer
        self.scheduler = scheduler
        self.criterion = criterion
        self.criterion2 = nn.MSELoss()
        self.tasks = tasks
        self.epochs = epochs

    def train(self,train_dl):
        for e in range(self.epochs):
            self.mtl.train()
            print("Epoch: {} ".format(e+1),end="")
            self.train_iter(train_dl)
            
    def train_iter(self, train_dl):
        train_loss = 0.0
        loss0sum = 0.0
        loss1sum = 0.0
        loss2sum = 0.0
        loss3sum = 0.0
        loss = torch.tensor(0)
        accuracy0 = 0.0
        accuracy1 = 0.0
        for i, (imgs, retinopathy_label, macular_edema_label, fovea_center_labels, optical_disk_labels) in enumerate(train_dl):
            fovea_center_labels[:0], fovea_center_labels[:1] = fovea_center_labels[:0]*Rx, fovea_center_labels[:1]*Ry
            optical_disk_labels[:0], optical_disk_labels[:1] = optical_disk_labels[:0]*Rx, optical_disk_labels[:1]*Ry
            batch_size = imgs.size(0)
            self.optimizer.zero_grad()
            retinopathy_pred, macular_edema_pred, fovea_center_pred, optical_disk_pred = self.mtl(imgs.to(device))
            loss0 = self.criterion(retinopathy_pred, retinopathy_label.to(device).to(torch.int64)).to(torch.float64)*10
            loss1 = self.criterion(macular_edema_pred, macular_edema_label.to(device).to(torch.int64)).to(torch.float64)*10
            loss2 = torch.sqrt(self.criterion2(fovea_center_pred.to(torch.double),fovea_center_labels.to(device).to(torch.double)))/100
            loss3 = torch.sqrt(self.criterion2(optical_disk_pred.to(torch.double),optical_disk_labels.to(device).to(torch.double)))/100
            loss0sum += loss0
            loss1sum += loss1
            loss2sum += loss2
            loss3sum += loss3
            pred0 = F.softmax(retinopathy_pred, dim = -1).argmax(dim=-1)
            accuracy0 += pred0.eq(retinopathy_label.to(device)).sum().item()
            pred1 = F.softmax(macular_edema_pred, dim = -1).argmax(dim=-1)
            accuracy1 += pred1.eq(macular_edema_label.to(device)).sum().item()
            loss = torch.stack((loss0, loss1, loss2 ,loss3))[self.tasks].sum()
            if i%4==0:
              print('=',end="")
            #print('Batch number: {}\nLoss on batch: {}\nLoss0: {}\nLoss1: {}\nLoss2: {}\nLoss3: {}\n-----------------------'.format(i,loss.item(), loss0.item(), loss1.item(), loss2.item() ,loss3.item()))
            loss.backward()
            self.optimizer.step()
            train_loss += loss
        print("\nTotal Loss: {}\nLoss0: {}  Accuracy0: {}\nLoss1: {}  Accuracy1: {}\nLoss2: {}\nLoss3: {}".format(train_loss, loss0sum, accuracy0*100/413, loss1sum, accuracy1*100/413,loss2sum,loss3sum))
        return train_loss

In [5]:
rx=1800
ry=1200
old_x=4288
old_y=2848
Rx=rx/old_x
Ry=ry/old_y

data_transformer = transforms.Compose([transforms.Resize((rx,ry)),transforms.ToTensor()])
train_ds = IDRiD_Dataset(data_transformer,'train')
train_dl = DataLoader(train_ds,batch_size=3,shuffle=True)
mtl = MTL()

if torch.cuda.is_available():
    device = torch.device("cuda")
    mtl=mtl.to(device)
    mtl=nn.DataParallel(mtl)

criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(mtl.parameters(),
                                weight_decay=1e-6,
                                momentum=0.9,
                                lr=1e-3,
                                nesterov=True)
scheduler = ReduceLROnPlateau(optimizer,
                                  factor=0.5,
                                  patience=3,
                                  min_lr=1e-7,
                                  verbose=True)
tasks=[[0,1,2,3]]

In [6]:
mtl_trainer=MTL_trainer(mtl,optimizer,scheduler,criterion,tasks,3)
mtl_trainer.train(train_dl)

Total Loss: 8413.86153288096
Loss0: 2180.370692014694  Accuracy0: 31.476997578692494
Loss1: 2055.119348168373  Accuracy1: 43.09927360774818
Loss2: 2090.250257522924
Loss3: 2088.1212351749696
Total Loss: 8424.023576428079
Loss0: 2165.13640999794  Accuracy0: 32.20338983050848
Loss1: 2038.5176730155945  Accuracy1: 42.857142857142854
Loss2: 2080.3062080747504
Loss3: 2140.0632853397938
Total Loss: 8273.948805949389
Loss0: 2107.0050621032715  Accuracy0: 32.445520581113804
Loss1: 1969.2409765720367  Accuracy1: 45.27845036319613
Loss2: 2088.4077139206242
Loss3: 2109.2950533534536
