# BiRA-Net: Bilinear attention net for diabetic retinopathy grading

To build a network inspired by successful models for diabetic retinopathy (DR) grading, we adapted BiRA-Net, an architecture designed specifically for this task. BiRA-Net combines an attention model for feature extraction with a bilinear model for fine-grained classification. In our adaptation, we replaced the ResNet component of the standard BiRA-Net with EfficientNet to improve feature extraction.
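
As a rough illustration of what the bilinear part contributes, the standard bilinear pooling formulation from the bilinear-CNN literature can be written as below. The symbols are ours: $f_A(l)$ and $f_B(l)$ are the feature vectors of two streams at spatial location $l$ of an $H \times W$ feature map. Whether `src.models.bira_net` follows this formulation exactly is an assumption, not something this notebook shows.

$$
B = \frac{1}{HW}\sum_{l=1}^{HW} f_A(l)\, f_B(l)^{\top}, \qquad
y = \operatorname{sign}\big(\operatorname{vec}(B)\big) \odot \sqrt{\lvert \operatorname{vec}(B) \rvert}, \qquad
z = \frac{y}{\lVert y \rVert_2}
$$

The outer product captures pairwise interactions between channels, which is why this kind of pooling is used for fine-grained classification; $z$ is the vector passed to the classifier.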

Paper:
https://arxiv.org/abs/1905.06312

GitHub:
https://github.com/ISS-Kerui/BIRA-NET-BILINEAR-ATTENTION-NET-FOR-DIABETIC-RETINOPATHY-GRADING

**Note:** We highly recommend running this notebook on a GPU. 

## 0. Initialization

In [1]:
import os

# Run from the parent directory so that relative paths such as data/ and results/ resolve
os.chdir("..")
import requests
import zipfile
from tqdm import tqdm
import pandas as pd
import torch
import torch.nn as nn
import torch.optim as optim
from src.models.bira_net import BiraNet
from src.utils import seed_everything
from src.loading import load_data, load_test_data
from src.train import train

In [2]:
# Set seeds
seed_everything()

# Set device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
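
The body of `seed_everything` lives in `src.utils` and is not shown in this notebook. A minimal sketch of what such a helper typically does is given below. The name `seed_everything_sketch` and the default seed of 42 are assumptions made for illustration. NumPy and PyTorch are seeded only if they are installed.

```python
import os
import random


def seed_everything_sketch(seed: int = 42) -> None:
    """Hypothetical stand-in for src.utils.seed_everything (not the project's code)."""
    # Seed Python's built-in random number generator
    random.seed(seed)
    # Hash seed for subprocesses started after this point
    # (it does not change hashing in the already-running interpreter)
    os.environ["PYTHONHASHSEED"] = str(seed)

    # Seed NumPy if it is available
    try:
        import numpy as np

        np.random.seed(seed)
    except ImportError:
        pass

    # Seed PyTorch (CPU and all visible GPUs) if it is available
    try:
        import torch

        torch.manual_seed(seed)
        if torch.cuda.is_available():
            torch.cuda.manual_seed_all(seed)
    except ImportError:
        pass


# Re-seeding with the same value reproduces the same random draws
seed_everything_sketch(0)
first = random.random()
seed_everything_sketch(0)
second = random.random()
print(first == second)
```

Seeding alone does not guarantee bit-identical GPU runs, since some CUDA kernels are non-deterministic, but it removes the largest sources of run-to-run variation.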



---

<a name='s1'></a>
## 1. Downloading the dataset

Fetching the dataset should take around 4-5 minutes. Unzipping takes about 20 seconds.

In [3]:
# If the folder 'data/' does not exist, download and extract the data
if not os.path.exists("data/"):
    # Dropbox URL
    dropbox_url = "https://www.dropbox.com/scl/fi/sa14unf8s47e9ym125zgo/data.zip?rlkey=198bg0cmbmmrcjkfufy9064wm&dl=1"

    # File path where the .zip file will be saved
    file_path = "data.zip"

    # Stream the download to disk in chunks rather than holding the whole archive in memory
    with requests.get(dropbox_url, stream=True) as response:
        # Stop here on an HTTP error so the unzip step below never runs on a missing file
        response.raise_for_status()
        with open(file_path, "wb") as file:
            for chunk in response.iter_content(chunk_size=1024 * 1024):
                file.write(chunk)

    print("Download successful. The file has been saved as 'data.zip'.")

    # Path to the downloaded .zip file
    zip_file_path = "data.zip"

    # Directory to extract the contents of the zip file ("." is the current working directory)
    extraction_path = "."

    # Unzipping the file
    with zipfile.ZipFile(zip_file_path, "r") as zip_ref:
        zip_ref.extractall(extraction_path)

    extraction_message = (
        f"The contents of the zip file have been extracted to: {extraction_path}"
    )

    print(extraction_message)

## 2. Data

Load the data with improved preprocessing. Set the batch size to suit your machine; we set it as high as our GPU memory allowed.

In [4]:
# Load the .jpeg files in the data folder
PATH_IMAGES = "data/images_keep_ar"
PATH_LABELS = "data/labels/trainLabels.csv"
batch_size = 16
img_size = (400, 400)
num_epochs = 20
num_classes = 5

Load the training and validation sets, split roughly 90/10.

In [5]:
train_loader, validation_loader = load_data(
    PATH_LABELS, PATH_IMAGES, img_size, batch_size
)

Number of train samples:31964
Number of validation samples:3162
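
The sample counts above, together with `batch_size = 16`, determine how many iterations each progress bar in the training log shows. The small helper below is a sketch of that arithmetic. The name `batches_per_epoch` is ours, and it assumes the loaders keep the final partial batch (the PyTorch `DataLoader` default, `drop_last=False`), which is consistent with the 1998 and 198 steps reported in the log.

```python
import math


def batches_per_epoch(num_samples: int, batch_size: int, drop_last: bool = False) -> int:
    """Number of batches one pass over the data produces."""
    if drop_last:
        # The incomplete final batch is discarded
        return num_samples // batch_size
    # The incomplete final batch is kept
    return math.ceil(num_samples / batch_size)


print(batches_per_epoch(31964, 16))  # training set   -> 1998
print(batches_per_epoch(3162, 16))   # validation set -> 198
```

If you change `batch_size` to fit your GPU, the same function gives the new progress-bar lengths.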


## 3. Model

Load the BiraNet model with an EfficientNet-B3 backbone. The backbone is our own model, pretrained on the diabetic retinopathy dataset.

In [6]:
# Initialize model
model = BiraNet(num_classes)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=3e-5)
model = model.to(device)

Optional: if you want to fine-tune this model, we advise doing it in steps. Load the last checkpoint to continue training.

In [7]:
# MODEL = "results/models/bira_baseline.pt"
# model.load_state_dict(torch.load(MODEL, map_location=device))

### 3.1 Model training and evaluation

In [8]:
train(
    model,
    train_loader,
    validation_loader,
    criterion,
    optimizer,
    device,
    model_name="results/models/bira_eff_net.pt",
    num_epochs=num_epochs,
)

100%|██████████| 1998/1998 [08:19<00:00,  4.00it/s]


Epoch: 1 Loss: 0.11222338676452637
Train accuracy:  0.8451382805656363
Train kappa score:  0.7683582384254026
---------------


100%|██████████| 198/198 [00:45<00:00,  4.34it/s]


Validation accuracy:  0.8184693232131562
Validation kappa score:  0.6970448963530279
---------------
New best model saved with kappa score: 0.6970448963530279


100%|██████████| 1998/1998 [08:30<00:00,  3.92it/s]


Epoch: 2 Loss: 0.23008887469768524
Train accuracy:  0.8549618320610687
Train kappa score:  0.8051324044687467
---------------


100%|██████████| 198/198 [00:45<00:00,  4.31it/s]


Validation accuracy:  0.8118279569892473
Validation kappa score:  0.6841168367727766
---------------


100%|██████████| 1998/1998 [08:27<00:00,  3.94it/s]


Epoch: 3 Loss: 0.5058364868164062
Train accuracy:  0.8541484169690903
Train kappa score:  0.805538260603288
---------------


100%|██████████| 198/198 [00:45<00:00,  4.38it/s]


Validation accuracy:  0.8121442125237192
Validation kappa score:  0.6879570270300579
---------------


100%|██████████| 1998/1998 [08:25<00:00,  3.96it/s]


Epoch: 4 Loss: 0.8341467380523682
Train accuracy:  0.8553059692153673
Train kappa score:  0.805776257090314
---------------


100%|██████████| 198/198 [00:45<00:00,  4.35it/s]


Validation accuracy:  0.8162555344718533
Validation kappa score:  0.6933094000944733
---------------


100%|██████████| 1998/1998 [08:25<00:00,  3.96it/s]


Epoch: 5 Loss: 0.611422598361969
Train accuracy:  0.856619947440871
Train kappa score:  0.8102780956895763
---------------


100%|██████████| 198/198 [00:44<00:00,  4.45it/s]


Validation accuracy:  0.8149905123339658
Validation kappa score:  0.6971858014924694
---------------
New best model saved with kappa score: 0.6971858014924694


100%|██████████| 1998/1998 [08:11<00:00,  4.06it/s]


Epoch: 6 Loss: 0.19261349737644196
Train accuracy:  0.8553998248029032
Train kappa score:  0.8089816532509806
---------------


100%|██████████| 198/198 [00:42<00:00,  4.66it/s]


Validation accuracy:  0.8111954459203036
Validation kappa score:  0.6933176393647454
---------------


100%|██████████| 1998/1998 [08:02<00:00,  4.14it/s]


Epoch: 7 Loss: 0.11881455034017563
Train accuracy:  0.856807658615943
Train kappa score:  0.8103601298045779
---------------


100%|██████████| 198/198 [00:43<00:00,  4.59it/s]


Validation accuracy:  0.8156230234029096
Validation kappa score:  0.7046498000135089
---------------
New best model saved with kappa score: 0.7046498000135089


100%|██████████| 1998/1998 [07:59<00:00,  4.17it/s]


Epoch: 8 Loss: 0.33350062370300293
Train accuracy:  0.8555875359779752
Train kappa score:  0.809853948633956
---------------


100%|██████████| 198/198 [00:42<00:00,  4.66it/s]


Validation accuracy:  0.8162555344718533
Validation kappa score:  0.6997732985646158
---------------


100%|██████████| 1998/1998 [07:59<00:00,  4.16it/s]


Epoch: 9 Loss: 0.27393218874931335
Train accuracy:  0.8581842072331373
Train kappa score:  0.8123947144490061
---------------


100%|██████████| 198/198 [00:43<00:00,  4.52it/s]


Validation accuracy:  0.8149905123339658
Validation kappa score:  0.7026566772506675
---------------


100%|██████████| 1998/1998 [07:59<00:00,  4.17it/s]


Epoch: 10 Loss: 0.3062385618686676
Train accuracy:  0.8564635214616444
Train kappa score:  0.8080269572121669
---------------


100%|██████████| 198/198 [00:42<00:00,  4.63it/s]


Validation accuracy:  0.8134092346616065
Validation kappa score:  0.6877432718569902
---------------


100%|██████████| 1998/1998 [08:01<00:00,  4.15it/s]


Epoch: 11 Loss: 0.5665830373764038
Train accuracy:  0.8559629583281191
Train kappa score:  0.8113499671028293
---------------


100%|██████████| 198/198 [00:43<00:00,  4.58it/s]


Validation accuracy:  0.8172043010752689
Validation kappa score:  0.6995009288734958
---------------


100%|██████████| 1998/1998 [08:00<00:00,  4.15it/s]


Epoch: 12 Loss: 0.20597444474697113
Train accuracy:  0.8568702290076335
Train kappa score:  0.8098269802399449
---------------


100%|██████████| 198/198 [00:42<00:00,  4.62it/s]


Validation accuracy:  0.8137254901960784
Validation kappa score:  0.6960922562667555
---------------


100%|██████████| 1998/1998 [08:00<00:00,  4.15it/s]


Epoch: 13 Loss: 0.5816908478736877
Train accuracy:  0.8582467776248279
Train kappa score:  0.814716093428898
---------------


100%|██████████| 198/198 [00:42<00:00,  4.66it/s]


Validation accuracy:  0.8162555344718533
Validation kappa score:  0.6969416409642106
---------------


100%|██████████| 1998/1998 [08:01<00:00,  4.15it/s]


Epoch: 14 Loss: 0.7200804352760315
Train accuracy:  0.8567763734200976
Train kappa score:  0.8101227040780041
---------------


100%|██████████| 198/198 [00:42<00:00,  4.62it/s]


Validation accuracy:  0.8140417457305503
Validation kappa score:  0.6921033199889333
---------------


100%|██████████| 1998/1998 [07:58<00:00,  4.17it/s]


Epoch: 15 Loss: 0.11357419937849045
Train accuracy:  0.8574646477286948
Train kappa score:  0.8142847226171238
---------------


100%|██████████| 198/198 [00:42<00:00,  4.65it/s]


Validation accuracy:  0.812460468058191
Validation kappa score:  0.6957350914276721
---------------


100%|██████████| 1998/1998 [08:00<00:00,  4.16it/s]


Epoch: 16 Loss: 0.29996195435523987
Train accuracy:  0.8579964960580653
Train kappa score:  0.8104105589513849
---------------


100%|██████████| 198/198 [00:42<00:00,  4.67it/s]


Validation accuracy:  0.8153067678684377
Validation kappa score:  0.7001462939061952
---------------


100%|██████████| 1998/1998 [08:00<00:00,  4.16it/s]


Epoch: 17 Loss: 0.37165117263793945
Train accuracy:  0.8560255287198097
Train kappa score:  0.8097458250630456
---------------


100%|██████████| 198/198 [00:43<00:00,  4.58it/s]


Validation accuracy:  0.8172043010752689
Validation kappa score:  0.6928594955804255
---------------


100%|██████████| 1998/1998 [08:01<00:00,  4.15it/s]


Epoch: 18 Loss: 0.42793428897857666
Train accuracy:  0.8567763734200976
Train kappa score:  0.8108987599744321
---------------


100%|██████████| 198/198 [00:43<00:00,  4.58it/s]


Validation accuracy:  0.812460468058191
Validation kappa score:  0.6915336290418173
---------------


100%|██████████| 1998/1998 [08:02<00:00,  4.14it/s]


Epoch: 19 Loss: 0.37102627754211426
Train accuracy:  0.8578400700788387
Train kappa score:  0.8128355498113287
---------------


100%|██████████| 198/198 [00:42<00:00,  4.65it/s]


Validation accuracy:  0.8159392789373814
Validation kappa score:  0.6983095301875841
---------------


100%|██████████| 1998/1998 [07:56<00:00,  4.19it/s]


Epoch: 20 Loss: 0.3165440857410431
Train accuracy:  0.859498185458641
Train kappa score:  0.8195471716253278
---------------


100%|██████████| 198/198 [00:43<00:00,  4.55it/s]

Validation accuracy:  0.8159392789373814
Validation kappa score:  0.6999211700552697
---------------
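
The log above reports a kappa score after every epoch, and a new best model is saved whenever the validation kappa improves. This notebook does not show how `train` computes the score. For DR grading the usual choice is the quadratically weighted Cohen's kappa, so the sketch below assumes that variant. The function name `quadratic_weighted_kappa` is ours and is not taken from `src.train`.

```python
def quadratic_weighted_kappa(y_true, y_pred, num_classes=5):
    """Quadratically weighted Cohen's kappa for integer labels in [0, num_classes)."""
    n = len(y_true)
    if n == 0 or n != len(y_pred):
        raise ValueError("y_true and y_pred must be non-empty and of equal length")

    # Confusion matrix: rows are true grades, columns are predicted grades
    observed = [[0] * num_classes for _ in range(num_classes)]
    for t, p in zip(y_true, y_pred):
        observed[t][p] += 1

    true_hist = [sum(row) for row in observed]
    pred_hist = [sum(observed[i][j] for i in range(num_classes)) for j in range(num_classes)]

    weighted_observed = 0.0
    weighted_expected = 0.0
    for i in range(num_classes):
        for j in range(num_classes):
            # The penalty grows with the squared distance between grades
            weight = (i - j) ** 2 / (num_classes - 1) ** 2
            # Count expected if true and predicted grades were independent
            expected = true_hist[i] * pred_hist[j] / n
            weighted_observed += weight * observed[i][j]
            weighted_expected += weight * expected

    if weighted_expected == 0.0:
        # Every label and prediction falls in one class; we treat this as
        # perfect agreement (a convention chosen for this sketch)
        return 1.0
    return 1.0 - weighted_observed / weighted_expected


print(quadratic_weighted_kappa([0, 1, 2, 3, 4], [0, 1, 2, 3, 4]))  # 1.0
print(quadratic_weighted_kappa([0, 0, 4, 4], [4, 4, 0, 0]))        # -1.0
```

Because the weights are quadratic, confusing grade 0 with grade 4 costs far more than confusing adjacent grades. This explains how validation accuracy can stay near 0.81 while kappa moves independently of it between epochs.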





## 4. Submission

The following code generates a submission file.

In [9]:
# Model path
MODEL = "results/models/bira_eff_net.pt"

# Initialize model
model = BiraNet(num_classes)
# map_location lets a checkpoint saved on a GPU load on a CPU-only machine
model.load_state_dict(torch.load(MODEL, map_location=device))
model = model.to(device)

In [10]:
img_size = (400, 400)
batch_size = 16
test = load_test_data("data/test/", img_size, batch_size)

In [12]:
# Test loop
model.eval()
test_preds = []
test_names = []
with torch.no_grad():
    for images, names in tqdm(test):
        images = images.to(device)
        outputs = model(images)
        predicted = outputs.argmax(dim=1)
        test_preds.extend(predicted.cpu().numpy())
        test_names.extend(names)

100%|██████████| 3349/3349 [12:36<00:00,  4.42it/s]


In [15]:
# Make the submission CSV: first column "image", second column "level"
submission = pd.DataFrame({"image": test_names, "level": test_preds})
submission.to_csv("submission.csv", index=False)
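
Before uploading, it can be worth checking that the file has the expected shape. The sketch below uses only the standard library. The function name `check_submission` and the image names in the example are hypothetical. The checks (header, grade range, no duplicate images) are our own suggestion and not a rule stated in this notebook.

```python
import csv
import io


def check_submission(csv_text: str, num_classes: int = 5) -> int:
    """Validate submission CSV text and return the number of data rows."""
    reader = csv.DictReader(io.StringIO(csv_text))
    if reader.fieldnames != ["image", "level"]:
        raise ValueError(f"unexpected header: {reader.fieldnames}")

    seen = set()
    for row in reader:
        level = int(row["level"])
        # Grades must be one of the num_classes integer classes
        if not 0 <= level < num_classes:
            raise ValueError(f"level out of range for image {row['image']}: {level}")
        # Each image should appear exactly once
        if row["image"] in seen:
            raise ValueError(f"duplicate image: {row['image']}")
        seen.add(row["image"])
    return len(seen)


# Example with made-up image names
print(check_submission("image,level\n1_left,0\n1_right,4\n"))  # 2
```

To check the real file, read it first, for example `check_submission(open("submission.csv").read())`.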