In [1]:
import sys
import numpy as np
import pandas as pd
import os
import cv2
import wandb
from datetime import datetime
from tqdm import tqdm

import torch
from torch.utils.data import Dataset, DataLoader
from torch.utils.data import ConcatDataset

sys.path.append("../")
from dataset.EyePACS_and_APTOS import Eye_APTOS
from dataset.messidor import MESSIDOR

# Use CUDA when it is available, otherwise fall back to CPU. This does not pick a
# specific GPU index; set CUDA_VISIBLE_DEVICES before CUDA is initialised for that.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Seed the RNGs so re-running the notebook reproduces the same results;
# torch.manual_seed also drives the default generator used by DataLoader(shuffle=True).
SEED = 42
np.random.seed(SEED)
torch.manual_seed(SEED)

### Change the following paths to your own. Extracting the data to /tmp or /home is recommended, since loading it from /l is very slow.

In [2]:
# Root folder holding the preprocessed datasets. Set the HC701_DATA_ROOT
# environment variable to point elsewhere instead of editing every path below.
DATA_ROOT = os.environ.get('HC701_DATA_ROOT', '/home/xiangjianhou/hc701-fed/preprocessed')

# Dataset name -> directory, consumed by the Eye_APTOS dataset class.
Eye_APTOS_data_dir_options = {
    'EyePACS': os.path.join(DATA_ROOT, 'eyepacs'),
    'APTOS': os.path.join(DATA_ROOT, 'aptos'),
}
# Dataset name -> directory, consumed by the MESSIDOR dataset class.
MESSIDOR_data_dir_options = {
    'messidor2': os.path.join(DATA_ROOT, 'messidor2'),
    'messidor_pairs': os.path.join(DATA_ROOT, 'messidor', 'messidor_pairs'),
    'messidor_Etienne': os.path.join(DATA_ROOT, 'messidor', 'messidor_Etienne'),
    'messidor_Brest-without_dilation': os.path.join(DATA_ROOT, 'messidor', 'messidor_Brest-without_dilation'),
}

In [3]:
# Per-site training splits, each loaded without any transform.
_train_kwargs = dict(train=True, transform=None)

APTOS_train = Eye_APTOS(data_dir=Eye_APTOS_data_dir_options['APTOS'], **_train_kwargs)
EyePACS_train = Eye_APTOS(data_dir=Eye_APTOS_data_dir_options['EyePACS'], **_train_kwargs)
MESSIDOR_2_train = MESSIDOR(data_dir=MESSIDOR_data_dir_options['messidor2'], **_train_kwargs)
MESSIDOR_pairs_train = MESSIDOR(data_dir=MESSIDOR_data_dir_options['messidor_pairs'], **_train_kwargs)
MESSIDOR_Etienne_train = MESSIDOR(data_dir=MESSIDOR_data_dir_options['messidor_Etienne'], **_train_kwargs)
MESSIDOR_Brest_train = MESSIDOR(data_dir=MESSIDOR_data_dir_options['messidor_Brest-without_dilation'], **_train_kwargs)

In [4]:
# Pool every site's training split into one centralized training set.
_train_sites = [
    APTOS_train,
    EyePACS_train,
    MESSIDOR_2_train,
    MESSIDOR_pairs_train,
    MESSIDOR_Etienne_train,
    MESSIDOR_Brest_train,
]
Centerlized_train = ConcatDataset(_train_sites)
# MESSIDOR-only pool: the pairs, Etienne and Brest subsets (messidor2 is excluded).
# NOTE(review): if messidor2 shares images with these subsets, the full pool above
# double-counts them — confirm against the dataset sources.
MESSIDOR_Centerlized_train = ConcatDataset(_train_sites[3:])

In [5]:
# Per-site held-out splits, each loaded without any transform.
_test_kwargs = dict(train=False, transform=None)

APTOS_test = Eye_APTOS(data_dir=Eye_APTOS_data_dir_options['APTOS'], **_test_kwargs)
EyePACS_test = Eye_APTOS(data_dir=Eye_APTOS_data_dir_options['EyePACS'], **_test_kwargs)
MESSIDOR_2_test = MESSIDOR(data_dir=MESSIDOR_data_dir_options['messidor2'], **_test_kwargs)
MESSIDOR_pairs_test = MESSIDOR(data_dir=MESSIDOR_data_dir_options['messidor_pairs'], **_test_kwargs)
MESSIDOR_Etienne_test = MESSIDOR(data_dir=MESSIDOR_data_dir_options['messidor_Etienne'], **_test_kwargs)
MESSIDOR_Brest_test = MESSIDOR(data_dir=MESSIDOR_data_dir_options['messidor_Brest-without_dilation'], **_test_kwargs)

In [6]:
# Pool every site's held-out split into one centralized test set.
_test_sites = [
    APTOS_test,
    EyePACS_test,
    MESSIDOR_2_test,
    MESSIDOR_pairs_test,
    MESSIDOR_Etienne_test,
    MESSIDOR_Brest_test,
]
Centerlized_test = ConcatDataset(_test_sites)
# MESSIDOR-only pool: the pairs, Etienne and Brest subsets (messidor2 is excluded).
MESSIDOR_Centerlized_test = ConcatDataset(_test_sites[3:])

In [7]:
# Sanity-check the first pooled training sample: tensor shape, label and pixel range.
# next(iter(...)) fetches a single batch without leaving a dangling progress bar.
img, label = next(iter(DataLoader(Centerlized_train, batch_size=1, shuffle=False)))
print(img.shape, label)
print(img.max(), img.min())

  0%|          | 0/40996 [00:00<?, ?it/s]

torch.Size([1, 3, 224, 224]) tensor([0])
tensor(1., dtype=torch.float64) tensor(0., dtype=torch.float64)





### How to load a model from the list in `HC701-PROJECT/model_list.txt`

In [5]:
from model.baseline import Baseline

In [7]:
# Change the backbone to your choice
# Training hyperparameters for this demo run.
NUM_EPOCHS = 10
BATCH_SIZE = 1
LEARNING_RATE = 0.001

# Nothing visible here moves Baseline to `device`, but the batches are moved there;
# put the model on the same device so a CUDA run does not hit a device mismatch.
model_demo = Baseline(backbone='densenet121', num_classes=5, pretrained=True).to(device)
loss = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model_demo.parameters(), lr=LEARNING_RATE)
model_save_path = 'type your path here'
# With shuffle=True the loader reshuffles on every pass, so one instance serves all epochs.
train_loader = DataLoader(Centerlized_train, batch_size=BATCH_SIZE, shuffle=True)

for epoch in range(NUM_EPOCHS):
    model_demo.train()
    running_loss = 0.0
    num_batches = 0
    for img, label in tqdm(train_loader, desc=f'epoch {epoch + 1}/{NUM_EPOCHS}'):
        # The sample batch printed earlier is float64, while torch model weights
        # default to float32; cast explicitly to avoid a dtype-mismatch error.
        img = img.to(device, dtype=torch.float32)
        label = label.to(device)
        output = model_demo(img)
        loss_value = loss(output, label)
        optimizer.zero_grad()
        loss_value.backward()
        optimizer.step()
        running_loss += loss_value.item()
        num_batches += 1
    # Save model after each epoch (overwrites the checkpoint at model_save_path)
    torch.save(model_demo.state_dict(), model_save_path)
    # Report the mean loss over the epoch instead of the last single-batch loss.
    print(f'epoch {epoch + 1}: mean loss {running_loss / max(num_batches, 1):.4f}')