In [None]:
import torch
import torch.nn as nn
import torchvision.transforms as transforms
import torch.nn.functional as F
import sklearn
import torchvision
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import LabelEncoder
import numpy as np
import pandas as pd
import os
import matplotlib.pyplot as plt
import PIL
from PIL import Image
import albumentations as A
from albumentations.pytorch import ToTensorV2
import seaborn as sns
import glob
from pathlib import Path
import random

# Seed every RNG the pipeline can draw from so a Restart & Run All is reproducible.
# Python's `random` is included because augmentation libraries may sample from it
# (NOTE(review): confirm for the installed albumentations version).
SEED = 1
random.seed(SEED)
torch.manual_seed(SEED)
np.random.seed(SEED)

In [None]:
# Ground-truth label tables for the three RFMiD splits, read in train/valid/test order.
train_labels, valid_labels, test_labels = (
    pd.read_csv(csv_path)
    for csv_path in (
        '../input/retinal-disease-classification/Training_Set/Training_Set/RFMiD_Training_Labels.csv',
        '../input/retinal-disease-classification/Evaluation_Set/Evaluation_Set/RFMiD_Validation_Labels.csv',
        '../input/retinal-disease-classification/Test_Set/Test_Set/RFMiD_Testing_Labels.csv',
    )
)

In [None]:
# Raw image filenames found in each split's image directory (os.listdir order, unsorted).
train_files = list(os.listdir('../input/retinal-disease-classification/Training_Set/Training_Set/Training'))
valid_files = list(os.listdir('../input/retinal-disease-classification/Evaluation_Set/Evaluation_Set/Validation'))
test_files = list(os.listdir('../input/retinal-disease-classification/Test_Set/Test_Set/Test'))

In [None]:
# Each image's ID is its filename up to the first '.', e.g. '12.png' -> '12' (still a string here).
train_ids = [name.partition('.')[0] for name in train_files]
valid_ids = [name.partition('.')[0] for name in valid_files]
test_ids = [name.partition('.')[0] for name in test_files]

In [None]:
# Pair every parsed ID with its filename in a two-column frame: 'ids', 'filenames'.
train_ids = pd.Series(train_ids, name='ids')
train_files = train_ids.to_frame().assign(filenames=train_files)

valid_ids = pd.Series(valid_ids, name='ids')
valid_files = valid_ids.to_frame().assign(filenames=valid_files)

test_ids = pd.Series(test_ids, name='ids')
test_files = test_ids.to_frame().assign(filenames=test_files)

In [None]:
# IDs were parsed from filenames as strings; cast to int64 so they can be matched
# against the label tables' 'ID' column when merging.
train_files = train_files.astype({'ids': 'int64'})
valid_files = valid_files.astype({'ids': 'int64'})
test_files = test_files.astype({'ids': 'int64'})

In [None]:
# Inner-join each label table with the files found on disk (label 'ID' == file 'ids');
# rows without a matching image are dropped by the inner join.
train_df = train_labels.merge(train_files, left_on='ID', right_on='ids')
valid_df = valid_labels.merge(valid_files, left_on='ID', right_on='ids')
test_df = test_labels.merge(test_files, left_on='ID', right_on='ids')
train_df

In [None]:
# NOTE(review): DataFrame.drop returns a new frame and leaves its input untouched.
# None of these results is assigned, so train_df/valid_df/test_df still carry the
# 'ids' column; the first two results are discarded and only the last is displayed.
# Later cells may index columns by position (iloc slices), so check them before
# assigning these results back.
train_df.drop(columns='ids')
valid_df.drop(columns='ids')
test_df.drop(columns='ids')

In [None]:
# Absolute-ish path to every image: the split's image directory prefixed onto the filename.
train_df['full_file_paths'] = train_df['filenames'].radd('../input/retinal-disease-classification/Training_Set/Training_Set/Training/')
valid_df['full_file_paths'] = valid_df['filenames'].radd('../input/retinal-disease-classification/Evaluation_Set/Evaluation_Set/Validation/')
test_df['full_file_paths'] = test_df['filenames'].radd('../input/retinal-disease-classification/Test_Set/Test_Set/Test/')

In [None]:
class RetinalDisease(torch.utils.data.Dataset):
    """Image dataset yielding ``(image, label_vector)`` pairs, one per row of ``df``.

    The image is read from the ``full_file_paths`` column and converted to RGB.
    The label vector holds every column of ``df`` except the bookkeeping columns
    in ``NON_LABEL_COLUMNS``, in column order, as a float32 tensor (the dtype
    ``BCEWithLogitsLoss`` expects for targets).

    Args:
        df: DataFrame with one row per sample. Defaults to ``train_df`` as it
            was bound when this cell ran.
        transform: Either a torchvision ``transforms.Compose`` (applied to the
            PIL image) or an albumentations-style callable invoked as
            ``transform(image=ndarray)["image"]``.
    """

    # Columns that identify or locate a sample rather than encode a label.
    NON_LABEL_COLUMNS = ('ID', 'ids', 'filenames', 'full_file_paths')
    PATH_COLUMN = 'full_file_paths'

    def __init__(self, df=train_df, transform=transforms.Compose([transforms.ToTensor()])):
        self.df = df
        self.transform = transform
        # Select labels by name so adding/dropping bookkeeping columns cannot
        # silently shift the label slice.
        self.label_columns = [col for col in df.columns if col not in self.NON_LABEL_COLUMNS]

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        # Read from *this* dataset's frame (the original read the global train_df,
        # giving validation/test samples training labels).
        row = self.df.iloc[idx]
        with Image.open(row[self.PATH_COLUMN]) as raw:
            image = raw.convert('RGB')

        if isinstance(self.transform, transforms.Compose):
            # torchvision pipelines take the PIL image positionally.
            image = self.transform(image)
        else:
            # albumentations pipelines take a numpy array by keyword.
            image = self.transform(image=np.array(image))["image"]

        label = torch.as_tensor(row[self.label_columns].to_numpy(dtype=np.float32))
        return image, label

In [None]:
# Shared preprocessing settings so the train and eval pipelines cannot drift apart.
IMAGE_HEIGHT, IMAGE_WIDTH = 1424, 2144
IMAGENET_MEAN = (0.485, 0.456, 0.406)  # normalisation stats matching the ImageNet-pretrained backbone
IMAGENET_STD = (0.229, 0.224, 0.225)

# Training: resize, random horizontal flip, normalise, convert to a CHW tensor.
train_transforms = A.Compose([
    A.Resize(IMAGE_HEIGHT, IMAGE_WIDTH),
    A.HorizontalFlip(),
    A.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
    ToTensorV2(),
])

# Evaluation: identical deterministic steps, no augmentation.
test_transforms = A.Compose([
    A.Resize(IMAGE_HEIGHT, IMAGE_WIDTH),
    A.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
    ToTensorV2(),
])

In [None]:
# One dataset per split; only the training split gets the augmenting pipeline.
train_dataset, valid_dataset, test_dataset = (
    RetinalDisease(transform=train_transforms),
    RetinalDisease(df=valid_df, transform=test_transforms),
    RetinalDisease(df=test_df, transform=test_transforms),
)

In [None]:
batch_size = 2
# Shuffle training batches: without it every epoch visits samples in the same
# fixed row order of train_df. Validation/test order does not affect the loss.
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
valid_loader = torch.utils.data.DataLoader(valid_dataset, batch_size=batch_size)
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=10)

In [None]:
# Number of batches the training loader yields per epoch.
len(train_loader)

In [None]:
# Prefer the first GPU when one is visible; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")
device

In [None]:
model = torchvision.models.resnet50(pretrained=True)
# Replace the ImageNet classifier with a 46-logit head; the width must match the
# number of label columns fed to BCEWithLogitsLoss.
model.fc = nn.Sequential(nn.Linear(model.fc.in_features, 46))

In [None]:
# Put the network on the target device before building the optimizer. nn.Module.to
# updates parameters in place, so the optimizer tracks the same tensors either way.
model = model.to(device)
criterion = nn.BCEWithLogitsLoss(reduction='sum').to(device)  # multi-label loss, summed over the batch

optimizer = torch.optim.SGD(
    model.parameters(),
    lr=0.0001,
    momentum=0.9,
    weight_decay=0.0005,
)

# Cut the LR 10x when validation loss stalls for 8 epochs, then wait 10 epochs.
lr_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
    optimizer, factor=0.1, patience=8, cooldown=10, verbose=True
)

In [None]:
epochs = 100

total_train_loss = []  # mean training loss of each epoch
total_valid_loss = []  # mean validation loss of each epoch
best_valid_loss = np.inf  # `np.Inf` was removed in NumPy 2.0


def train_one_epoch(net, loader, loss_fn, opt, dev):
    """Run one optimisation pass of ``net`` over ``loader``.

    Returns:
        list[float]: the loss of every batch, in iteration order.
    """
    net.train()
    batch_losses = []
    for image, target in loader:
        image, target = image.to(dev), target.to(dev).float()
        # Clear the previous step's gradients; the original loop never did,
        # so gradients accumulated across every batch of the epoch.
        opt.zero_grad()
        loss = loss_fn(net(image), target)
        loss.backward()
        opt.step()
        batch_losses.append(loss.item())
    return batch_losses


def evaluate(net, loader, loss_fn, dev):
    """Compute per-batch losses of ``net`` on ``loader`` in eval mode, without autograd.

    Returns:
        list[float]: the loss of every batch, in iteration order.
    """
    net.eval()
    batch_losses = []
    with torch.no_grad():
        for image, target in loader:
            image, target = image.to(dev), target.to(dev).float()
            batch_losses.append(loss_fn(net(image), target).item())
    return batch_losses


for epoch in range(epochs):
    print('Epoch: ', epoch + 1)
    train_loss = train_one_epoch(model, train_loader, criterion, optimizer, device)
    valid_loss = evaluate(model, valid_loader, criterion, device)

    epoch_train_loss = np.mean(train_loss)
    epoch_valid_loss = np.mean(valid_loss)
    print(f'Epoch {epoch + 1}, train loss: {epoch_train_loss:.4f}, valid loss: {epoch_valid_loss:.4f}')

    # Checkpoint whenever validation loss reaches a new minimum.
    if epoch_valid_loss < best_valid_loss:
        torch.save(model.state_dict(), 'retinal_disease.pt')
        print('Model improved. Saving model.')
        best_valid_loss = epoch_valid_loss

    lr_scheduler.step(epoch_valid_loss)
    total_train_loss.append(epoch_train_loss)
    total_valid_loss.append(epoch_valid_loss)