In [1]:
import torchvision.models as models
import torch
from torchvision import transforms
from PIL import Image
import numpy as np
from torch import nn
from dataset import IDRiD_Dataset
from torch.utils.data import DataLoader
from torch.optim.lr_scheduler import ReduceLROnPlateau


In [2]:
class MTL(nn.Module):
    def __init__(self):
        super(MTL, self).__init__()
        resnet50 = models.resnet50(pretrained=True)
        self.features = torch.nn.Sequential(*(list(resnet50.children())[:-1]))
        self.last = nn.Sequential(nn.Linear(2048, 1024),nn.ReLU())
        self.retinopathy_classifier = nn.Sequential(nn.Linear(1024, 512),nn.ReLU(), nn.Linear(512, 5), nn.Softmax(dim=1))
        self.macular_edema_classifier = nn.Sequential(nn.Linear(1024, 512),nn.ReLU(), nn.Linear(512, 5), nn.Softmax(dim=1))
        self.fovea_center_cords = nn.Sequential(nn.Linear(1024, 512),nn.ReLU(), nn.Linear(512, 2))
        self.optical_disk_cords = nn.Sequential(nn.Linear(1024, 512),nn.ReLU(), nn.Linear(512, 2))

    def forward(self,data):
        out = self.features.forward(data).squeeze()
        out = self.last.forward(out)
        return (self.retinopathy_classifier(out),
                self.macular_edema_classifier(out),
                self.fovea_center_cords(out),
                self.optical_disk_cords(out))


In [3]:
#50x25 just for implementing
data_transformer = transforms.Compose([transforms.Resize((500,250)),transforms.ToTensor()])
train_ds = IDRiD_Dataset(data_transformer,'train')
train_dl = DataLoader(train_ds,batch_size=2,shuffle=True)
mtl = MTL()

In [4]:
class MTL_trainer():
    def __init__(self,mtl,optimizer, scheduler, criterion, tasks, epochs):
        self.mtl = mtl
        self.optimizer = optimizer
        self.scheduler = scheduler
        self.criterion = criterion
        self.criterion2 = nn.MSELoss()
        self.tasks = tasks
        self.epochs = epochs

    def train(self,train_dl):
        for e in range(self.epochs):
            self.mtl.train()
            print(self.train_iter(train_dl).item())
            

    def train_iter(self, train_dl):
        train_loss = 0.0
        loss = torch.tensor(0)
        for i, (imgs, retinopathy_label, macular_edema_label, fovea_center_labels, optical_disk_labels) in enumerate(train_dl):
            batch_size = imgs.size(0)
            self.optimizer.zero_grad()
            retinopathy_pred, macular_edema_pred, fovea_center_pred, optical_disk_pred = self.mtl(imgs)
            loss0 = self.criterion(retinopathy_pred, retinopathy_label.to(torch.int64)).to(torch.float64)*100
            loss1 = self.criterion(macular_edema_pred, macular_edema_label.to(torch.int64)).to(torch.float64)*100
            loss2 = torch.sqrt(self.criterion2(fovea_center_pred.to(torch.double),fovea_center_labels.to(torch.double)))/10
            loss3 = torch.sqrt(self.criterion2(optical_disk_pred.to(torch.double),optical_disk_labels.to(torch.double)))/10
            loss = torch.stack((loss0, loss1, loss2 ,loss3))[self.tasks].sum()
            print('Batch number: {}\nLoss on batch: {}\nLoss0: {}\nLoss1: {}\nLoss2: {}\nLoss3: {}\n-----------------------'.format(i,loss.item(), loss0.item(), loss1.item(), loss2.item() ,loss3.item()))
            loss.backward()
            self.optimizer.step()
            train_loss += loss
        return train_loss

In [5]:
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(mtl.parameters(),
                                weight_decay=1e-6,
                                momentum=0.9,
                                lr=1e-3,
                                nesterov=True)
scheduler = ReduceLROnPlateau(optimizer,
                                  factor=0.5,
                                  patience=3,
                                  min_lr=1e-7,
                                  verbose=True)
tasks=[[0,1,2,3]]

In [74]:
mtl_trainer=MTL_trainer(mtl,optimizer,scheduler,criterion,tasks,2)
mtl_trainer.train(train_dl)

Batch number: 0
Loss on batch: 7.241564384506855
Loss0: 1.5996904373168945
Loss1: 1.6031171083450317
Loss2: 1.618451106433125
Loss3: 2.4203057324118036
-----------------------
Batch number: 1
Loss on batch: 6.905180373988713
Loss0: 1.6074074506759644
Loss1: 1.6019704341888428
Loss2: 1.8358105584608195
Loss3: 1.8599919306630874
-----------------------
Batch number: 2
Loss on batch: 6.8507506300268535
Loss0: 1.6038973331451416
Loss1: 1.6029586791992188
Loss2: 1.7924470261231553
Loss3: 1.8514475915593374
-----------------------
Batch number: 3
Loss on batch: 7.163377940470573
Loss0: 1.6049728393554688
Loss1: 1.6015524864196777
Loss2: 1.6077287260215134
Loss3: 2.3491238886739136
-----------------------
Batch number: 4
Loss on batch: 7.236367794161778
Loss0: 1.6024727821350098
Loss1: 1.5963181257247925
Loss2: 1.6711564361399855
Loss3: 2.3664204501619905
-----------------------
Batch number: 5
Loss on batch: 6.903582478417452
Loss0: 1.5949654579162598
Loss1: 1.5937516689300537
Loss2: 1.78779

KeyboardInterrupt: 