In [None]:
import torchvision.models as models
import torch
import torch.nn.functional as F
from torchvision import transforms
from PIL import Image
import copy
import numpy as np
from torch import nn
from dataset import IDRiD_Dataset
from torch.utils.data import DataLoader
from torch.optim.lr_scheduler import ReduceLROnPlateau
import pandas as pd

In [None]:
# Colab-only: mount Google Drive at /content/drive.
# The training cell saves weights under ./drive/MyDrive and the inference cell reads them back.
from google.colab import drive
drive.mount('/content/drive')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [None]:
class MTL(nn.Module):
    """Multi-task network: a shared ResNet-50 trunk followed by four small heads.

    The trunk is every child module of a pretrained torchvision ResNet-50 except
    the last one, followed by a shared 2048 -> 1024 layer. Four heads branch off
    that shared representation:

    * ``retinopathy_classifier``   -- 5 logits (task 0)
    * ``macular_edema_classifier`` -- 5 logits (task 1)
    * ``fovea_center_cords``       -- 2 regressed values (task 2)
    * ``optical_disk_cords``       -- 2 regressed values (task 3)
    """

    def __init__(self):
        super(MTL, self).__init__()
        resnet50 = models.resnet50(pretrained=True)
        # Keep all ResNet-50 children except the last one (its classification layer).
        self.features = torch.nn.Sequential(*(list(resnet50.children())[:-1]))
        self.last = nn.Sequential(nn.Linear(2048, 1024),nn.ReLU())
        self.retinopathy_classifier = nn.Sequential(nn.Linear(1024, 512),nn.ReLU(), nn.Linear(512, 5))
        self.macular_edema_classifier = nn.Sequential(nn.Linear(1024, 512),nn.ReLU(), nn.Linear(512, 5))
        self.fovea_center_cords = nn.Sequential(nn.Linear(1024, 512),nn.ReLU(), nn.Linear(512, 2))
        self.optical_disk_cords = nn.Sequential(nn.Linear(1024, 512),nn.ReLU(), nn.Linear(512, 2))

    def forward(self, data):
        """Run the trunk and all four heads.

        Args:
            data: image batch accepted by the ResNet-50 trunk.

        Returns:
            Tuple ``(retinopathy_logits, macular_edema_logits, fovea_xy, optical_disk_xy)``.
        """
        # NOTE: squeeze() drops *every* size-1 dimension, so a batch holding a single
        # image also loses its batch axis and the heads return 1-D tensors. Existing
        # callers rely on that for single-image inference, so it is kept as is; make
        # sure training batches never end up with exactly one sample.
        out = self.features(data).squeeze()
        out = self.last(out)
        return (self.retinopathy_classifier(out),
                self.macular_edema_classifier(out),
                self.fovea_center_cords(out),
                self.optical_disk_cords(out))

    def fit(self, train_dl, optimizer, scheduler, criterion, tasks, epochs,
            weights_path='./drive/MyDrive/weights.pt', coord_scale=None):
        """Train the model, checkpointing whenever the epoch training loss improves.

        Args:
            train_dl: iterable of batches ``(imgs, retinopathy_label,
                macular_edema_label, fovea_center_labels, optical_disk_labels)``.
            optimizer: optimizer built over this model's parameters.
            scheduler: learning-rate scheduler or ``None``. It is stepped once per
                epoch; a ``ReduceLROnPlateau`` receives the epoch training loss.
            criterion: loss used for the two classification heads.
            tasks: indices (0-3, flat or nested list) of the task losses that are
                summed into the optimised loss.
            epochs: number of passes over ``train_dl``.
            weights_path: where the best ``state_dict`` is written.
            coord_scale: optional ``(scale_x, scale_y)`` applied to the coordinate
                labels; when omitted the module-level ``Rx`` and ``Ry`` are used.
        """
        self.optimizer = optimizer
        self.scheduler = scheduler
        self.criterion = criterion
        self.criterion2 = nn.MSELoss()
        self.tasks = tasks
        self.epochs = epochs
        best_loss = float("Inf")
        for e in range(self.epochs):
            self.train()
            print("Epoch: {} ".format(e+1),end="")
            current_loss = self.fit_iter(train_dl, coord_scale)
            epoch_loss = current_loss.item()
            # The scheduler was previously stored but never stepped, so the
            # learning rate could never change.
            if self.scheduler is not None:
                if isinstance(self.scheduler, ReduceLROnPlateau):
                    self.scheduler.step(epoch_loss)
                else:
                    self.scheduler.step()
            if epoch_loss < best_loss:
                best_loss = epoch_loss
                torch.save(self.state_dict(), weights_path)
                print("Saved best model weights!")

    def fit_iter(self, train_dl, coord_scale=None):
        """Run one training epoch and return the summed optimised loss.

        Args:
            train_dl: iterable of batches, see :meth:`fit`.
            coord_scale: optional ``(scale_x, scale_y)`` for the coordinate labels;
                defaults to the module-level ``Rx`` and ``Ry``.

        Returns:
            0-dim float64 tensor holding the sum over batches of the selected task
            losses (detached from the autograd graph).
        """
        if coord_scale is None:
            coord_scale = (Rx, Ry)
        # Follow the model's own device instead of depending on a global `device`.
        device = next(self.parameters()).device
        # Assumes coordinate labels are (batch, 2) with x in column 0 and y in
        # column 1, matching the 2-unit regression heads -- TODO confirm in dataset.
        scale = torch.tensor(coord_scale, dtype=torch.double, device=device)
        # Flatten so both [0, 1, 2] and [[0, 1, 2]] select the same task losses,
        # independent of how the installed torch interprets nested-list indices.
        task_idx = torch.as_tensor(self.tasks, dtype=torch.long, device=device).flatten()

        train_loss = 0.0
        task_loss_sums = [0.0, 0.0, 0.0, 0.0]
        correct0 = 0
        correct1 = 0
        n_seen = 0
        for i, (imgs, retinopathy_label, macular_edema_label, fovea_center_labels, optical_disk_labels) in enumerate(train_dl):
            imgs = imgs.to(device)
            retinopathy_label = retinopathy_label.to(device).to(torch.int64)
            macular_edema_label = macular_edema_label.to(device).to(torch.int64)
            # Scale x by scale_x and y by scale_y on a fresh tensor. The earlier
            # `labels[:0]` / `labels[:1]` slices addressed rows (nothing, and the
            # first sample) rather than the coordinate columns.
            fovea_target = fovea_center_labels.to(device).to(torch.double) * scale
            optical_disk_target = optical_disk_labels.to(device).to(torch.double) * scale
            n_seen += imgs.size(0)

            self.optimizer.zero_grad()
            retinopathy_pred, macular_edema_pred, fovea_center_pred, optical_disk_pred = self(imgs)
            # Classification losses are weighted x10, coordinate RMSEs /10.
            loss0 = self.criterion(retinopathy_pred, retinopathy_label).to(torch.float64)*10
            loss1 = self.criterion(macular_edema_pred, macular_edema_label).to(torch.float64)*10
            loss2 = torch.sqrt(self.criterion2(fovea_center_pred.to(torch.double), fovea_target))/10
            loss3 = torch.sqrt(self.criterion2(optical_disk_pred.to(torch.double), optical_disk_target))/10
            losses = torch.stack((loss0, loss1, loss2, loss3))
            loss = losses[task_idx].sum()
            loss.backward()
            self.optimizer.step()

            # Accumulate plain floats so no autograd graph is kept alive across batches.
            for k, value in enumerate(losses.detach().tolist()):
                task_loss_sums[k] += value
            train_loss += loss.item()
            # argmax over logits equals argmax over softmax(logits).
            correct0 += retinopathy_pred.argmax(dim=-1).eq(retinopathy_label).sum().item()
            correct1 += macular_edema_pred.argmax(dim=-1).eq(macular_edema_label).sum().item()
            if i%4==0:
                print('=',end="")

        # Use the number of samples actually seen instead of a hard-coded dataset size.
        denom = max(n_seen, 1)
        print("\nTotal Loss: {}\nLoss0: {}  Accuracy0: {}\nLoss1: {}  Accuracy1: {}\nLoss2: {}\nLoss3: {}".format(train_loss, task_loss_sums[0], correct0*100/denom, task_loss_sums[1], correct1*100/denom, task_loss_sums[2], task_loss_sums[3]))
        return torch.tensor(train_loss, dtype=torch.float64)

In [None]:
# Target image size (rx wide, ry high) and the source image size the coordinate
# labels are assumed to be expressed in (old_x wide, old_y high) -- TODO confirm
# against the dataset. Rx / Ry rescale label coordinates to the resized image.
rx=450
ry=300
old_x=4288
old_y=2848
Rx=rx/old_x
Ry=ry/old_y

# torchvision's Resize takes (height, width), so the height (ry) goes first;
# passing (rx, ry) produced 450-high x 300-wide images that disagreed with Rx/Ry.
data_transformer = transforms.Compose([transforms.Resize((ry,rx)),transforms.ToTensor(),
                                       transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])])
train_ds = IDRiD_Dataset(data_transformer,'train')
train_dl = DataLoader(train_ds,batch_size=32,shuffle=True)

# Always define `device` so later cells also work on CPU-only runtimes.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
mtl = MTL().to(device)

criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(mtl.parameters(),
                                weight_decay=1e-6,
                                momentum=0.9,
                                lr=1e-3,
                                nesterov=True)
scheduler = ReduceLROnPlateau(optimizer,
                                  factor=0.5,
                                  patience=3,
                                  min_lr=1e-7,
                                  verbose=True)
# Indices of the task losses that are optimised (0: retinopathy, 1: macular edema,
# 2: fovea centre); task 3 (optic disc) is computed and reported but not optimised.
tasks=[[0, 1, 2]]

Downloading: "https://download.pytorch.org/models/resnet50-0676ba61.pth" to /root/.cache/torch/hub/checkpoints/resnet50-0676ba61.pth


  0%|          | 0.00/97.8M [00:00<?, ?B/s]

In [None]:
# Start training; the best weights (lowest epoch training loss) are checkpointed by fit().
mtl.fit(
    train_dl=train_dl,
    optimizer=optimizer,
    scheduler=scheduler,
    criterion=criterion,
    tasks=tasks,
    epochs=30000,
)

In [None]:
# Reload the best checkpoint into a fresh CPU model and export its raw outputs
# for every training sample (one CSV row per image: the four heads concatenated).
mtl=MTL()
# map_location lets a checkpoint saved from a CUDA model load on a CPU-only runtime.
mtl.load_state_dict(torch.load("./drive/MyDrive/weights.pt", map_location="cpu"))
mtl.eval()
z = []
# Inference only: skip autograd bookkeeping instead of detaching every output.
with torch.no_grad():
    for idx in range(len(train_ds)):
        imgs = train_ds[idx][0]  # first element of each sample is the image tensor
        one_row = mtl(imgs[None,:])
        # Flatten each head so the row is built the same way whether or not the
        # model keeps a leading batch axis of size 1.
        z.append(torch.cat([head.reshape(-1) for head in one_row]).numpy())
data = pd.DataFrame(z)
data.to_csv('teacher_train_predictions.csv')