In [1]:
import torchvision.models as models
import torch
from torchvision import transforms
from PIL import Image
import numpy as np
from torch import nn
import torch.nn.functional as F
from dataset import IDRiD_Dataset
from torch.utils.data import DataLoader


In [54]:
class MTL(nn.Module):
    def __init__(self):
        super(MTL, self).__init__()
        resnet50 = models.resnet50(pretrained=True)
        self.features = torch.nn.Sequential(*(list(resnet50.children())[:-1]))
        self.last = nn.Sequential(nn.Linear(2048, 1024),nn.ReLU())
        self.retinopathy_classifier = nn.Sequential(nn.Linear(1024, 512),nn.ReLU(), nn.Linear(512, 5), nn.Softmax(dim=1))
        self.macular_edema_classifier = nn.Sequential(nn.Linear(1024, 512),nn.ReLU(), nn.Linear(512, 5), nn.Softmax(dim=1))
        self.fovea_center_cords = nn.Sequential(nn.Linear(1024, 512),nn.ReLU(), nn.Linear(512, 2))
        self.optical_disk_cords = nn.Sequential(nn.Linear(1024, 512),nn.ReLU(), nn.Linear(512, 2))

    def forward(self,data):
        out = self.features.forward(data).squeeze()
        out = self.last.forward(out)
        return (self.retinopathy_classifier(out),
                self.macular_edema_classifier(out),
                self.fovea_center_cords(out),
                self.optical_disk_cords(out))


In [62]:
data_transformer=transforms.Compose([transforms.Resize((2000,1000)),transforms.ToTensor()])
train_ds=IDRiD_Dataset(data_transformer,'train')
train_dl=DataLoader(train_ds,batch_size=2,shuffle=True)
mtl=MTL()

In [63]:
for img,label in train_dl:
    print(mtl.forward(img))
    break

(tensor([[0.2054, 0.2190, 0.1901, 0.1827, 0.2028],
        [0.2078, 0.2176, 0.1894, 0.1845, 0.2007]], grad_fn=<SoftmaxBackward>), tensor([[0.2083, 0.1877, 0.2043, 0.2041, 0.1955],
        [0.2064, 0.1896, 0.2011, 0.2072, 0.1957]], grad_fn=<SoftmaxBackward>), tensor([[0.0430, 0.1557],
        [0.0636, 0.1648]], grad_fn=<AddmmBackward>), tensor([[0.0242, 0.0697],
        [0.0397, 0.0663]], grad_fn=<AddmmBackward>))
