In [None]:
import torchvision.models as models
import torch
from torchvision import transforms
from PIL import Image
import numpy as np
from torch import nn
from dataset import IDRiD_Dataset
from torch.utils.data import DataLoader
from torch.optim.lr_scheduler import ReduceLROnPlateau


In [None]:
class MTL(nn.Module):
    """Multi-task network: a shared ResNet-50 trunk feeding four task heads.

    The trunk is every child module of torchvision's ResNet-50 except the
    last one, followed by a shared ``2048 -> 1024`` projection with ReLU.
    Four heads branch off that shared representation:

    * ``retinopathy_classifier``   -- 5 class scores
    * ``macular_edema_classifier`` -- 5 class scores
    * ``fovea_center_cords``       -- 2 regression outputs
    * ``optical_disk_cords``       -- 2 regression outputs

    The two classifier heads return raw, unnormalised logits, which is what
    ``nn.CrossEntropyLoss`` expects (it applies log-softmax internally).
    Apply ``torch.softmax(logits, dim=1)`` yourself when probabilities are
    needed, e.g. at inference time.
    """

    def __init__(self):
        super(MTL, self).__init__()
        # `pretrained=True` is kept for compatibility with older torchvision
        # releases; newer ones prefer the `weights=` argument.
        resnet50 = models.resnet50(pretrained=True)
        # Drop the final child of the ResNet so the trunk yields pooled features.
        self.features = torch.nn.Sequential(*(list(resnet50.children())[:-1]))
        self.last = nn.Sequential(nn.Linear(2048, 1024), nn.ReLU())
        # Classification heads: no trailing Softmax, so they emit logits.
        self.retinopathy_classifier = nn.Sequential(
            nn.Linear(1024, 512), nn.ReLU(), nn.Linear(512, 5))
        # NOTE(review): 5 outputs mirrors the retinopathy head -- confirm this
        # matches the number of macular-edema grades in the dataset.
        self.macular_edema_classifier = nn.Sequential(
            nn.Linear(1024, 512), nn.ReLU(), nn.Linear(512, 5))
        # Regression heads: two unbounded outputs each.
        self.fovea_center_cords = nn.Sequential(
            nn.Linear(1024, 512), nn.ReLU(), nn.Linear(512, 2))
        self.optical_disk_cords = nn.Sequential(
            nn.Linear(1024, 512), nn.ReLU(), nn.Linear(512, 2))

    def forward(self, data):
        """Run the shared trunk once and evaluate every head on its output.

        Args:
            data: batch of images accepted by the ResNet trunk.

        Returns:
            Tuple ``(retinopathy_logits, macular_edema_logits,
            fovea_center, optical_disk)``; each element keeps the batch
            dimension, including for a batch of one.
        """
        pooled = self.features(data)
        # Flatten everything except the batch axis. A bare `.squeeze()` would
        # also remove the batch axis when the batch holds a single sample.
        shared = self.last(torch.flatten(pooled, 1))
        return (self.retinopathy_classifier(shared),
                self.macular_edema_classifier(shared),
                self.fovea_center_cords(shared),
                self.optical_disk_cords(shared))


In [None]:
# Placeholder 50x25 resolution, used only while the pipeline is being implemented.
data_transformer = transforms.Compose([
    transforms.Resize(size=(50, 25)),
    transforms.ToTensor(),
])
train_ds = IDRiD_Dataset(data_transformer, 'train')
train_dl = DataLoader(dataset=train_ds, batch_size=32, shuffle=True)
mtl = MTL()

In [None]:
class MTL_trainer():
    """Minimal training loop for the multi-task model.

    Attributes:
        mtl: model returning ``(retinopathy_pred, macular_edema_pred,
            fovea_center_pred, optical_disk_pred)`` for a batch of images.
        optimizer: optimizer updating the model parameters.
        scheduler: learning-rate scheduler stepped once per epoch, or
            ``None`` to disable scheduling.
        criterion: classification loss applied to both classifier outputs.
        tasks: indices into the per-task losses ``(retinopathy,
            macular_edema)`` selecting which ones are summed into the
            training loss.
        epochs: number of passes over the training data.
    """

    def __init__(self, mtl, optimizer, scheduler, criterion, tasks, epochs):
        self.mtl = mtl
        self.optimizer = optimizer
        self.scheduler = scheduler
        self.criterion = criterion
        self.tasks = tasks
        self.epochs = epochs

    def train(self, train_dl):
        """Train for ``self.epochs`` epochs, printing each epoch's summed loss."""
        for _ in range(self.epochs):
            self.mtl.train()
            epoch_loss = self.train_iter(train_dl)
            print(epoch_loss)
            self._step_scheduler(epoch_loss)

    def _step_scheduler(self, epoch_loss):
        """Advance the LR scheduler once; plateau schedulers receive the epoch loss."""
        if self.scheduler is None:
            return
        if isinstance(self.scheduler, ReduceLROnPlateau):
            # ReduceLROnPlateau needs the monitored metric to detect a plateau.
            self.scheduler.step(float(epoch_loss))
        else:
            self.scheduler.step()

    def _task_index(self):
        """Return ``self.tasks`` as a flat index usable on the stacked losses."""
        if not isinstance(self.tasks, (list, tuple, torch.Tensor)):
            # Ints, slices, etc. already index a 1-D tensor unambiguously.
            return self.tasks
        index = torch.as_tensor(self.tasks)
        if index.dtype != torch.bool:
            index = index.long()
        # Flatten so nested lists such as [[0, 1]] do not depend on how a
        # given PyTorch version interprets list-of-list indexing.
        return index.reshape(-1)

    def train_iter(self, train_dl):
        """Run one epoch over ``train_dl`` and return the summed batch losses.

        Args:
            train_dl: iterable yielding ``(imgs, retinopathy_label,
                macular_edema_label, fovea_center_labels,
                optical_disk_labels)`` batches.

        Returns:
            The sum of the per-batch losses, detached from the autograd
            graph (``0.0`` when ``train_dl`` yields no batches).
        """
        train_loss = 0.0
        selected = self._task_index()
        for (imgs, retinopathy_label, macular_edema_label,
             fovea_center_labels, optical_disk_labels) in train_dl:
            self.optimizer.zero_grad()
            retinopathy_pred, macular_edema_pred, fovea_center_pred, optical_disk_pred = self.mtl(imgs)
            loss0 = self.criterion(retinopathy_pred, retinopathy_label.to(torch.int64))
            loss1 = self.criterion(macular_edema_pred, macular_edema_label.to(torch.int64))
            # TODO: losses for the two coordinate-regression heads
            # (fovea_center_pred / optical_disk_pred) are not implemented yet.
            loss = torch.stack((loss0, loss1))[selected].sum()
            print(loss)
            loss.backward()
            self.optimizer.step()
            # Detach before accumulating so the running total does not keep
            # every batch's autograd history alive.
            train_loss += loss.detach()
        return train_loss

In [None]:
# Loss applied to both classification heads.
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(
    mtl.parameters(),
    lr=1e-3,
    momentum=0.9,
    weight_decay=1e-6,
    nesterov=True,
)
# `verbose` is omitted: it is deprecated in recent PyTorch releases. Inspect
# `optimizer.param_groups[0]["lr"]` to see the current learning rate.
scheduler = ReduceLROnPlateau(
    optimizer,
    factor=0.5,
    patience=3,
    min_lr=1e-7,
)
# Indices of the per-task losses to sum: 0 = retinopathy, 1 = macular edema.
# A flat list avoids relying on how nested lists are interpreted as indices.
tasks = [0, 1]

In [None]:
# Build the trainer for two epochs and launch training on the training loader.
mtl_trainer = MTL_trainer(
    mtl=mtl,
    optimizer=optimizer,
    scheduler=scheduler,
    criterion=criterion,
    tasks=tasks,
    epochs=2,
)
mtl_trainer.train(train_dl=train_dl)