In [1]:
import sys
import numpy as np
import pandas as pd
import os
import cv2
import wandb
from datetime import datetime
from tqdm import tqdm

import torch
from torch.utils.data import Dataset, DataLoader
from torch.utils.data import ConcatDataset


from hc701fed.dataset.EyePACS_and_APTOS import Eye_APTOS
from hc701fed.dataset.messidor import MESSIDOR

# Which GPU to use
device = torch.device("cuda:2" if torch.cuda.is_available() else "cpu")

In [2]:
Eye_APTOS_data_dir_options = {
    'EyePACS': '/home/xiangjianhou/hc701-fed/preprocessed/eyepacs',
    'APTOS': '/home/xiangjianhou/hc701-fed/preprocessed/aptos',
}
MESSIDOR_data_dir_options = {
    'messidor2': '/home/xiangjianhou/hc701-fed/preprocessed/messidor2',
    'messidor_pairs' : '/home/xiangjianhou/hc701-fed/preprocessed/messidor/messidor_pairs',
    'messidor_Etienne' : '/home/xiangjianhou/hc701-fed/preprocessed/messidor/messidor_Etienne',
    'messidor_Brest-without_dilation' : '/home/xiangjianhou/hc701-fed/preprocessed/messidor/messidor_Brest-without_dilation'
}

In [7]:
MESSIDOR_pairs_train = MESSIDOR(data_dir=MESSIDOR_data_dir_options['messidor_pairs'], train=True, transform=None)
MESSIDOR_Etienne_train = MESSIDOR(data_dir=MESSIDOR_data_dir_options['messidor_Etienne'], train=True, transform=None)
MESSIDOR_Brest_train = MESSIDOR(data_dir=MESSIDOR_data_dir_options['messidor_Brest-without_dilation'], train=True, transform=None)
MESSIDOR_Centerlized_train = ConcatDataset([MESSIDOR_pairs_train, MESSIDOR_Etienne_train,MESSIDOR_Brest_train])

MESSIDOR_pairs_test = MESSIDOR(data_dir=MESSIDOR_data_dir_options['messidor_pairs'], train=False, transform=None)
MESSIDOR_Etienne_test = MESSIDOR(data_dir=MESSIDOR_data_dir_options['messidor_Etienne'], train=False, transform=None)
MESSIDOR_Brest_test = MESSIDOR(data_dir=MESSIDOR_data_dir_options['messidor_Brest-without_dilation'], train=False, transform=None)
MESSIDOR_Centerlized_test = ConcatDataset([MESSIDOR_pairs_test, MESSIDOR_Etienne_test,MESSIDOR_Brest_test])

In [8]:
MESSIDOR_Centerlized_train_loader = DataLoader(MESSIDOR_Centerlized_train, batch_size=32, shuffle=True, num_workers=0)
MESSIDOR_Centerlized_test_loader = DataLoader(MESSIDOR_Centerlized_test, batch_size=32, shuffle=True, num_workers=0)

In [24]:
from hc701fed.model.baseline import Baseline
model_demo = Baseline(backbone='densenet121',num_classes=4,pretrained=True)
loss = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model_demo.parameters(), lr=0.001)
model_save_path = '/home/xiangjianhou/hc701-fed/checkpoint/MESSIDOR_3_hosptial_4class'

In [25]:
for epoch in range(100):
    model_demo.train()
    model_demo.to(device)
    for i, (images, labels) in enumerate(MESSIDOR_Centerlized_train_loader):
        images = images.to(device,torch.float32)
        labels = labels.to(device)
        outputs = model_demo(images)
        loss_value = loss(outputs, labels)
        optimizer.zero_grad()
        loss_value.backward()
        optimizer.step()
        print('epoch: {}, batch: {}, loss: {}'.format(epoch, i, loss_value.item()))
    if epoch % 1 == 0:
        torch.save(model_demo.state_dict(), os.path.join(model_save_path, 'model_{}.pth'.format(epoch)))

epoch: 0, batch: 0, loss: 1.44024658203125
epoch: 0, batch: 1, loss: 1.361745834350586
epoch: 0, batch: 2, loss: 1.026029348373413
epoch: 0, batch: 3, loss: 1.1484640836715698
epoch: 0, batch: 4, loss: 1.4625872373580933
epoch: 0, batch: 5, loss: 1.7400227785110474
epoch: 0, batch: 6, loss: 1.3524588346481323
epoch: 0, batch: 7, loss: 1.223073124885559
epoch: 0, batch: 8, loss: 1.291237235069275
epoch: 0, batch: 9, loss: 1.183354377746582
epoch: 0, batch: 10, loss: 1.2781695127487183
epoch: 0, batch: 11, loss: 1.1489821672439575
epoch: 0, batch: 12, loss: 1.2569248676300049
epoch: 0, batch: 13, loss: 1.255894422531128
epoch: 0, batch: 14, loss: 1.2994041442871094
epoch: 0, batch: 15, loss: 1.2678431272506714
epoch: 0, batch: 16, loss: 1.0039314031600952
epoch: 0, batch: 17, loss: 1.0960462093353271
epoch: 0, batch: 18, loss: 1.2995574474334717
epoch: 0, batch: 19, loss: 1.064724087715149
epoch: 0, batch: 20, loss: 1.1173396110534668
epoch: 0, batch: 21, loss: 1.2939146757125854
epoch: 

In [26]:
# test model accuracy and f1 score
from sklearn.metrics import f1_score
model_demo.eval()
model_demo.to(device)
y_true = []
y_pred = []
for i, (images, labels) in enumerate(MESSIDOR_Centerlized_test_loader):
    images = images.to(device,torch.float32)
    labels = labels.to(device)
    outputs = model_demo(images)
    outputs = torch.argmax(outputs, dim=1)
    y_true.append(labels.cpu().numpy())
    y_pred.append(outputs.cpu().numpy())
y_true = np.concatenate(y_true)
y_pred = np.concatenate(y_pred)
print('accuracy: {}'.format(np.mean(y_true == y_pred)))
print('f1 score: {}'.format(f1_score(y_true, y_pred, average='macro')))

accuracy: 0.6366666666666667
f1 score: 0.5353231339862357
