In [None]:
# This Python 3 environment comes with many helpful analytics libraries installed
# It is defined by the kaggle/python Docker image: https://github.com/kaggle/docker-python
# For example, here are several helpful packages to load

import numpy as np # linear algebra
import pandas as pd # data processing, CSV file I/O (e.g. pd.read_csv)
import glob
import seaborn as sns
%matplotlib inline
import os
import numpy as np
from PIL import Image, ImageDraw
import matplotlib.pylab as plt
import torch
# Input data files are available in the read-only "../input/" directory
# For example, running this (by clicking run or pressing Shift+Enter) will list all files under the input directory
# Root folder of the training split (contains the image sub-folders and the label spreadsheet).
path2data = '../input/challamd/Training400/'
# Use the GPU when one is visible; otherwise fall back to the CPU so the
# notebook still runs on machines without CUDA.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
# You can write up to 20GB to the current directory (/kaggle/working/) that gets preserved as output when you create a version using "Save & Run All" 
# You can also write temporary files to /kaggle/temp/, but they won't be saved outside of the current session

In [None]:
!pip install openpyxl

In [None]:
# Read the fovea annotation spreadsheet that sits in the dataset root folder
# and preview its first rows.
path2labels = os.path.join(path2data,'Fovea_location.xlsx')
labels = pd.read_excel(path2labels)
labels.head()

In [None]:
# Colour the fovea positions by the first character of each file name
# ('A' is the prefix load_image maps to the AMD folder).
AorN = [file_name[0] for file_name in labels.imgName]
sns.scatterplot(
    x=labels.Fovea_X,
    y=labels.Fovea_Y,
    hue=AorN,
)

In [None]:
def load_image(labels_ds,id):
    """Open one image and return it together with its fovea coordinates.

    Parameters
    ----------
    labels_ds : object exposing ``.values`` (the labels table).
        Rows are read positionally: element 1 is the file name,
        elements 2 and 3 are the fovea x and y.
    id : int
        Row position inside ``labels_ds.values``.

    Returns
    -------
    tuple
        ``(image, (x, y))`` where ``image`` is whatever ``Image.open`` returns.
    """
    row = labels_ds.values[id]
    file_name = row[1]
    # Files whose name starts with 'A' live in the AMD folder, all others in Non-AMD.
    if file_name[0] == 'A':
        subfolder = 'AMD'
    else:
        subfolder = 'Non-AMD'
    picture = Image.open(os.path.join(path2data, subfolder, file_name))
    return picture, (row[2], row[3])
# Quick check: load the sample at row position 5 and keep the image and its
# fovea coordinates at notebook scope.
image,(x,y) = load_image(labels,5)


In [None]:
def Draw_image(image,x,y,w=300,h=300,line_width=1):
    """Outline a w-by-h green box centred on (x, y) and return the pixels.

    The rectangle is drawn through ``ImageDraw.Draw(image)``, i.e. directly
    onto the image object that was passed in.

    Parameters
    ----------
    image : image accepted by ``ImageDraw.Draw``.
    x, y : centre of the box, in pixels.
    w, h : box width and height, in pixels.
    line_width : outline thickness.

    Returns
    -------
    numpy.ndarray
        ``np.asarray(image)`` taken after drawing.
    """
    half_w, half_h = w/2, h/2
    top_left = (x - half_w, y - half_h)
    bottom_right = (x + half_w, y + half_h)
    ImageDraw.Draw(image).rectangle(
        (top_left, bottom_right),
        width=line_width,
        outline='green',
    )
    return np.asarray(image)


In [None]:
import torchvision.transforms.functional as TF
def img_resize(image,labels,target_size=(256,256)):
    """Resize an image and rescale its (x, y) label to the new size.

    Parameters
    ----------
    image : image exposing ``.size`` as ``(width, height)``.
    labels : tuple
        ``(x, y)`` coordinates in pixels of the original image.
    target_size : tuple
        Desired output size as ``(width, height)``.

    Returns
    -------
    tuple
        ``(resized_image, (new_x, new_y))``.
    """
    x,y = labels
    o_w,o_h = image.size
    t_w,t_h = target_size
    # torchvision's resize takes the size as (height, width), whereas
    # target_size and image.size are (width, height) - swap before calling so
    # the image and the label are scaled consistently for non-square targets.
    n_img = TF.resize(image,(t_h,t_w))
    n_labels = (x*(t_w/o_w),y*(t_h/o_h))
    return n_img,n_labels

In [None]:
def horizontal_flip(image,labels):
    """Mirror the image left-right and reflect the label's x coordinate.

    Parameters
    ----------
    image : image exposing ``.size`` as ``(width, height)``.
    labels : tuple ``(x, y)`` in pixels.

    Returns
    -------
    tuple
        ``(flipped_image, (width - x, y))``.
    """
    width, _ = image.size
    cx, cy = labels
    return TF.hflip(image), (width - cx, cy)

In [None]:
def scale_label(label,image_size):
    """Divide each label coordinate by the matching image dimension.

    Returns a list with one quotient per ``(coordinate, dimension)`` pair;
    pairing stops at the shorter of the two inputs.
    """
    scaled = []
    for coord, extent in zip(label, image_size):
        scaled.append(coord / extent)
    return scaled

In [None]:
def rescale_label(label,image_size):
    """Multiply each label coordinate by the matching image dimension.

    This is the inverse of ``scale_label``. It returns a list with one
    product per ``(coordinate, dimension)`` pair; pairing stops at the
    shorter input.
    """
    restored = []
    for fraction, extent in zip(label, image_size):
        restored.append(fraction * extent)
    return restored

In [None]:
def vertical_flip(image,labels):
    """Mirror the image top-bottom and reflect the label's y coordinate.

    Parameters
    ----------
    image : image exposing ``.size`` as ``(width, height)``.
    labels : tuple ``(x, y)`` in pixels.

    Returns
    -------
    tuple
        ``(flipped_image, (x, height - y))``.
    """
    w,h = image.size
    x,y = labels
    image = TF.vflip(image)
    # A vertical flip mirrors y about the image HEIGHT (not the width).
    labels = (x,h-y)
    return image,labels

In [None]:
def translate(image,labels,max_translation=(0.2,0.2)):
    """Shift the image by a random offset and move the label with it.

    Two uniform numbers are drawn from ``np.random.rand`` (x first, then y),
    mapped to [-1, 1) and scaled by ``max_translation`` times the image
    width/height; the resulting offsets are truncated to integers.

    Parameters
    ----------
    image : image exposing ``.size`` as ``(width, height)``.
    labels : tuple ``(x, y)`` in pixels.
    max_translation : tuple
        Maximum shift as a fraction of ``(width, height)``.

    Returns
    -------
    tuple
        ``(shifted_image, (x + dx, y + dy))``.
    """
    width, height = image.size
    cx, cy = labels
    frac_x, frac_y = max_translation
    # Keep the draw order (x, then y) so seeded runs stay reproducible.
    rand_x = np.random.rand()*2 - 1
    rand_y = np.random.rand()*2 - 1
    dx = int(rand_x*frac_x*width)
    dy = int(rand_y*frac_y*height)
    shifted = TF.affine(image, translate=(dx, dy), angle=0, shear=0, scale=1)
    return shifted, (cx + dx, cy + dy)


In [None]:
def transformer(image,label,params):
    """Resize, randomly augment and normalise one (image, label) pair.

    Steps, in order: resize to ``params['target_size']``; horizontal flip
    with probability ``params['p_hflip']``; vertical flip with probability
    ``params['p_vflip']`` (treated as 0 when the key is absent); random
    translation with probability ``params['p_shift']`` bounded by
    ``params['max_translation']``; finally the label is divided by the
    target size via ``scale_label``.

    Parameters
    ----------
    image : input image.
    label : tuple ``(x, y)`` in pixels of the input image.
    params : dict with the keys listed above.

    Returns
    -------
    tuple
        ``(image, label)`` where ``label`` is the list returned by
        ``scale_label``.
    """
    image,label = img_resize(image,label,params['target_size'])
    if np.random.rand()<params['p_hflip']:
        image,label = horizontal_flip(image,label)
    # The second branch used to repeat the horizontal flip; it is the
    # vertical flip, driven by its own probability.
    if np.random.rand()<params.get('p_vflip',0.0):
        image,label = vertical_flip(image,label)
    if np.random.rand()<params['p_shift']:
        image,label = translate(image,label,params['max_translation'])
    label = scale_label(label,params['target_size'])
    return image,label

In [None]:
import random 
# Visual check of the augmentation pipeline on the sample at row position 123:
# flips are disabled (probability 0) and the random shift is always applied
# (p_shift = 1.0).
img, label=load_image(labels,123)
params={
    "target_size" : (256, 256), "p_hflip" : .0,
    "p_vflip" : .0,
    "p_shift" : 1.0, "max_translation": (0.2, 0.2),
}
img_t,label_t=transformer(img,label,params)
# transformer returns the label divided by the target size; convert it back to
# pixel coordinates before drawing.
label_t = rescale_label(label_t,params['target_size'])
print(label_t)
# Mark the transformed fovea position with a 30x30 box.
image = Draw_image(img_t,*label_t,w=30,h=30,line_width=3)
plt.imshow(image)

In [None]:
# Preview the resize + vertical-flip path on two samples.
ids = [1,232,]
# Create ONE figure up front: opening a new figure inside the loop put every
# sample in its own figure, so the 3x3 subplot grid was never actually shared.
plt.figure(figsize = (15,15))
for i,_id in enumerate(ids):
    image,(x,y) = vertical_flip(*img_resize(*load_image(labels,_id),target_size=(250,250)))
    plt.subplot(3,3,i+1)
    # Mark the (flipped) fovea position with a 50x50 box.
    img = Draw_image(image,x,y,w=50,h=50,line_width=1)
    plt.imshow(img)
    # Column 1 of the labels table holds the image file name.
    plt.title(labels.values[_id,1])
plt.show()

In [None]:
from torch.utils.data import Dataset
class AMD_Dataset(Dataset):
    """Dataset of images with their fovea (x, y) coordinates.

    The labels are read from ``Fovea_location.xlsx`` inside ``path2data``
    (indexed by its ``ID`` column). Images are looked up in the ``AMD``
    sub-folder when the file name starts with ``'A'`` and in ``Non-AMD``
    otherwise.

    Parameters
    ----------
    path2data : str
        Folder holding the spreadsheet and the two image sub-folders.
    transform : callable, optional
        Called as ``transform(image, label, transform_params)`` and expected
        to return ``(image, label)``. Any non-callable value (including the
        default ``0``) selects the module-level ``transformer`` function.
    transform_params : object
        Passed through unchanged as the third argument of the transform.
    """
    def __init__(self,path2data,transform=0,transform_params=0):
        # Honour a caller-supplied transform; previously the argument was
        # silently ignored in favour of the module-level `transformer`.
        self.transformer = transform if callable(transform) else transformer
        self.transformer_params = transform_params
        path2labels = os.path.join(path2data,'Fovea_location.xlsx')
        labels_pd   = pd.read_excel(path2labels,index_col = 'ID')
        self.labels = labels_pd[['Fovea_X','Fovea_Y']].values
        self.ids    = labels_pd.index
        # Store paths by ROW POSITION so they always line up with self.labels,
        # even when the ID column is not the contiguous range 1..N.
        self.path2images = np.empty(len(self.ids),dtype=object)
        for position,img_name in enumerate(labels_pd['imgName']):
            prefix = 'AMD' if img_name[0] =='A' else 'Non-AMD'
            self.path2images[position] = os.path.join(path2data,prefix,img_name)

    def __len__(self):
        """Number of labelled samples."""
        return len(self.ids)

    def __getitem__(self,idx):
        """Return ``(image_array, label)`` for the sample at position ``idx``.

        ``image_array`` is the transformed image as a channel-first array;
        ``label`` is whatever the transform returns for the label.
        """
        # Close the file as soon as the pixels are loaded, and force three
        # channels so the (2, 0, 1) transpose below is always valid.
        with Image.open(self.path2images[idx]) as img_file:
            image = img_file.convert('RGB')
        label = self.labels[idx]
        image, label = self.transformer(image,label,self.transformer_params)
        return np.asarray(image).transpose(2,0,1),label
# Build the dataset with a 256x256 target size; the flip and shift
# probabilities are all zero in this configuration.
data = AMD_Dataset(path2data=path2data,transform_params={
    "target_size" : (256, 256), "p_hflip" : .0,
    "p_vflip" : .0,
    "p_shift" : 0, "max_translation": (0.2, 0.2),
})

In [None]:
from torch.utils.data import Subset
from sklearn.model_selection import ShuffleSplit
# Single shuffled 80/20 split of the sample positions (fixed seed for
# reproducibility), wrapped as two Subset views over the same dataset.
ss = ShuffleSplit(n_splits=1, test_size=0.2, random_state=42)
indices = range(len(data))
# n_splits=1, so the generator yields exactly one (train, validation) pair.
train_index, val_index = next(ss.split(indices))
train_sub = Subset(data, train_index)
val_sub = Subset(data, val_index)

Test the dataset: fetch one batch from the training DataLoader and inspect its contents

In [None]:
from torch.utils.data import DataLoader
import torch
# Shuffled loader over the training subset, 8 samples per batch.
train_dl = DataLoader(train_sub,batch_size=8,shuffle=True)
# Smoke test: look at the first batch only.
for img,label in train_dl:
    print(img.shape)
    # The dataset returns each label as a two-element list, so the batch label
    # presumably arrives as a sequence of per-coordinate tensors; stacking along
    # dim 1 joins them into one tensor with the coordinates on the last axis.
    label = torch.stack(label,1)
    print(label)
    break

In [None]:
import torch.nn as nn
import torch.nn.functional as F
class Net(nn.Module):
    """Small CNN regressor with concatenated skip connections.

    Each of the first four stages runs a 3x3 convolution + ReLU + 2x2 max
    pooling and concatenates the result (along channels) with an
    average-pooled copy of that stage's input. A fifth convolution, global
    average pooling and a linear layer produce the outputs.

    Parameters
    ----------
    params : dict, optional
        May contain ``'input_channels'`` (default 3), ``'initial_filter'``
        (default 16) and ``'number_out'`` (default 2). Missing keys fall back
        to these defaults.
    """
    def __init__(self,params=None):
        super(Net,self).__init__()
        # Build the defaults per call instead of sharing one mutable dict
        # across all instances, and let callers override only some keys.
        defaults = {'input_channels':3,'initial_filter':16,'number_out':2}
        params = {**defaults, **(params or {})}
        C_in = params['input_channels']
        init_f = params['initial_filter']
        num_outputs = params['number_out']
        # Input widths grow as C_in + (2**k - 1) * init_f because every stage
        # concatenates its conv output with its own (pooled) input.
        self.conv1 = nn.Conv2d(C_in,init_f,kernel_size=3,padding=1,stride=2)
        self.conv2 = nn.Conv2d(C_in+init_f,  2*init_f,kernel_size=3,padding=1,stride=1)
        self.conv3 = nn.Conv2d(C_in+init_f*3,4*init_f,kernel_size=3,padding=1,stride=1)
        self.conv4 = nn.Conv2d(C_in+init_f*7,8*init_f,kernel_size=3,padding=1,stride=1)
        self.conv5 = nn.Conv2d(C_in+init_f*15,16*init_f,kernel_size=3,padding=1,stride=1)
        self.fcl   = nn.Linear(16*init_f,num_outputs)

    def forward(self,x):
        """Map a batch of images ``(N, C_in, H, W)`` to ``(N, number_out)``."""
        # Stage 1 shrinks the map by 4 (stride-2 conv, then 2x2 max pool), so
        # the skip branch is average-pooled by 4 to match.
        identity = F.avg_pool2d(x,kernel_size = 4)
        x = F.max_pool2d(F.relu(self.conv1(x)),2)
        x = torch.cat((x,identity),dim=1)

        # Stages 2-4 shrink by 2 each (stride-1 conv, then 2x2 max pool).
        for conv in (self.conv2,self.conv3,self.conv4):
            identity = F.avg_pool2d(x,kernel_size = 2)
            x = F.max_pool2d(F.relu(conv(x)),2)
            x = torch.cat((x,identity),dim=1)

        x = F.relu(self.conv5(x))

        # Global average pooling makes the head independent of spatial size.
        x = F.adaptive_avg_pool2d(x,1)
        x = torch.flatten(x,start_dim=1)
        x = self.fcl(x)
        return x
# Instantiate the network with its default hyper-parameters.
model = Net()


In [None]:
from torch import optim
# Adam over all model parameters with a learning rate of 3e-4.
# NOTE(review): the training cell further down builds its own optimizer (`opt`);
# this one does not appear to be used for training in the visible notebook.
optimizer = optim.Adam(model.parameters(),lr=3e-4)
def get_lr(opt):
    """Return the learning rate stored in the optimizer's first parameter group."""
    first_group = opt.param_groups[0]
    return first_group['lr']
# Sanity check: read back and print the optimizer's learning rate.
current_lr = get_lr(optimizer)
print('current lr = {}'.format(current_lr))


In [None]:
from torch.optim.lr_scheduler import ReduceLROnPlateau
# Halve the learning rate (factor=0.5) once the monitored value has not
# decreased (mode='min') for 20 consecutive scheduler steps.
# NOTE(review): no call to scheduler.step(...) is visible in this notebook.
scheduler = ReduceLROnPlateau(optimizer,mode='min',factor=0.5,patience=20,verbose=1)


In [None]:
# Smooth L1 loss, SUMMED over every element of the batch (not averaged).
loss_func =nn.SmoothL1Loss(reduction='sum')

In [None]:
def cxcy2box(cxcy,w=50./256,h=50./256):
    """Turn centre points into fixed-size boxes ``(x_min, y_min, x_max, y_max)``.

    Parameters
    ----------
    cxcy : torch.Tensor of shape ``(N, 2)``
        Box centres ``(cx, cy)``.
    w, h : float
        Box width and height, in the same units as ``cxcy``.

    Returns
    -------
    torch.Tensor of shape ``(N, 4)``
        On the same device as ``cxcy``.
    """
    # Build the half-extent on the input's own device (and floating dtype)
    # instead of relying on a module-level `device`, so CPU tensors work too.
    dtype = cxcy.dtype if cxcy.is_floating_point() else torch.get_default_dtype()
    half_extent = torch.tensor([w,h],dtype=dtype,device=cxcy.device)/2
    # (N, 2) -/+ (2,) broadcasts the same half-extent over every centre.
    X_Y_min = cxcy-half_extent
    X_Y_max = cxcy+half_extent
    return torch.cat((X_Y_min,X_Y_max),1)

In [None]:
import torchvision
def metrics_batch(output,target):
    """Sum of the IoU between each predicted box and its own target box.

    Both inputs are centre points that ``cxcy2box`` expands into boxes.
    ``box_iou`` returns the full pairwise matrix; only its diagonal
    (prediction i versus target i) is kept and summed.

    Returns
    -------
    float
        The summed IoU over the batch.
    """
    pred_boxes = cxcy2box(output)
    true_boxes = cxcy2box(target)
    pairwise_iou = torchvision.ops.box_iou(pred_boxes, true_boxes)
    matched_iou = pairwise_iou.diagonal()
    return matched_iou.sum().item()

In [None]:
def loss_batch(loss_func, output, target, opt=None):
    """Compute loss and metric for one batch, optionally taking an optimizer step.

    Parameters
    ----------
    loss_func : callable ``(output, target) -> loss tensor``.
    output, target : tensors passed to ``loss_func`` and ``metrics_batch``.
    opt : optimizer or None
        When given, gradients are reset, back-propagated and one step is taken.

    Returns
    -------
    tuple
        ``(loss_value, metric)`` with the loss as a Python number.
    """
    batch_loss = loss_func(output, target)
    # The metric is for reporting only, so keep it out of the autograd graph.
    with torch.no_grad():
        batch_metric = metrics_batch(output, target)
    if opt is None:
        return batch_loss.item(), batch_metric
    opt.zero_grad()
    batch_loss.backward()
    opt.step()
    return batch_loss.item(), batch_metric


In [None]:
def loss_epoch(model,loss_func,train_dl,val_dl=None,opt=None):
    """Run one pass over ``train_dl`` and return the per-sample averages.

    Each batch is moved to the module-level ``device``, the inputs are cast
    to double, and ``loss_batch`` is called (which also takes an optimizer
    step when ``opt`` is given). ``loss_batch`` is treated as returning batch
    TOTALS (the notebook's loss uses ``reduction='sum'`` and the metric sums
    the IoU), so the totals are divided by the number of samples seen.

    Parameters
    ----------
    model : callable mapping an input batch to predictions.
    loss_func : loss passed through to ``loss_batch``.
    train_dl : iterable of ``(x, y)`` batches, ``y`` being a sequence of
        per-coordinate tensors.
    val_dl : unused; kept for interface compatibility.
    opt : optimizer or None, passed through to ``loss_batch``.

    Returns
    -------
    tuple
        ``(mean_loss, mean_metric)`` over the epoch, or ``(0.0, 0.0)`` when
        ``train_dl`` yields no samples.
    """
    running_loss = 0.0
    running_metric = 0.0
    n_samples = 0
    for x,y in train_dl:
        y = torch.stack(y,1).to(device=device)
        output = model(x.double().to(device=device))
        loss,metric = loss_batch(loss_func,output,y,opt)
        running_loss += loss
        running_metric += metric
        n_samples += y.shape[0]
    if n_samples == 0:
        return 0.0,0.0
    # Return the epoch averages, not the values of the last batch.
    return running_loss/n_samples,running_metric/n_samples

In [None]:
#### loss_func=nn.SmoothL1Loss(reduction="sum")
# GPUtil is only used to print GPU utilisation; it was referenced without ever
# being imported. Import it here and simply skip the report when it is missing.
try:
    import GPUtil
except ImportError:
    GPUtil = None

opt = optim.Adam(model.parameters(), lr=1e-5)
# NOTE(review): lr_scheduler is created but never stepped inside the loop below.
lr_scheduler = ReduceLROnPlateau(opt, mode='min',factor=0.5, patience=20,verbose=1)
# loss_epoch feeds double-precision inputs on `device`, so the model must match.
model = model.double()
model.to(device)
loss_history= []
for i in range(40):
    print(i)
    if GPUtil is not None:
        GPUtil.showUtilization()
    loss,metric = loss_epoch(model,loss_func,train_dl,opt=opt)
    loss_history.append(loss)
    print(loss)
# train and validate the model
#model,loss_hist,metric_hist=train_val(model,params_train)

In [None]:
# Show the value collected from loss_epoch after each training epoch.
print(loss_history)


In [None]:
# Compare predictions with the ground truth on one training batch. Labels are
# normalised by the 256-pixel target size, so multiply by 256 to read pixels.
# Inference only: no_grad avoids building an autograd graph for the forward pass.
with torch.no_grad():
    for imgs,label in train_dl:
        label = torch.stack(label,1)
        predic = model(imgs.double().to(device))
        print(predic*256)
        print(label*256)
        break
