In [1]:
import torch
import os
from Confuse import ConfuseMatrix
from torch import nn
from torch.nn import functional as F
from torchvision import transforms
from torchvision.io import read_image
import numpy as np
from torch.utils.data import Dataset, DataLoader, random_split
from sys import stdout
import time
import pandas as pd
# import numpy as np

In [2]:
if __name__ == "__main__":
    device = 'cuda' if torch.cuda.is_available() else 'cpu'

In [3]:
class img_normalize():

    def __init__(self):
        pass


    def __call__(self, img):
        if img.dtype == torch.uint8:
            img = img.to(torch.float32)
            img /= 255
        elif (img.dtype == np.uint8) or (img.dtype == np.uint16):
            img = img.astype('float32')
            img /= 255
        return img
    


class OcularDataset(Dataset):

    def __init__(self, datadir, csv, transforms = None):
        self.csv = csv
        self.datadir = datadir
        self.transform = transforms


    def __len__(self):
        return len(self.csv)
    

    def __getitem__(self, idx):
        if torch.is_tensor(idx):
            idx = idx.tolist()
        img_name = os.path.join(self.datadir,
                                self.csv.iloc[idx, -1])
        image = read_image(img_name)
        label = self.csv.iloc[idx, -2]

        if self.transform:
            image = self.transform(image)
        return image, label

In [4]:
torch.manual_seed(3)
# device = 'cuda'

## data loading
datadir = 'ocular-disease-recognition-odir5k/ODIR-5K/ODIR-5K/Training Images'
csvdir = 'ocular-disease-recognition-odir5k/full_df.csv'
csv = pd.read_csv(csvdir)
temp = ([np.array(eval(csv.iloc[i, -2]), dtype = 'float32') for i in range(len(csv))])
csv.target = temp
batchsize = 100
transform = transforms.Compose([transforms.Resize((224, 224)), img_normalize()])
dataset = OcularDataset(datadir, csv, transform)
train_dataset, valid_dataset, test_dataset = random_split(dataset, [.7, .2, .1])
num_workers = 6
train_dataloader = DataLoader(train_dataset, batch_size = batchsize, shuffle=False, num_workers= num_workers, pin_memory=True)
valid_dataloader = DataLoader(valid_dataset, batch_size = batchsize, shuffle=False, num_workers = 0, pin_memory=True)
test_dataloader = DataLoader(test_dataset, batch_size = batchsize, shuffle=False, num_workers = 0, pin_memory=True)

In [5]:
class PatchEmbedding(nn.Module):
    def __init__(self, img_size, patch_size, in_channels, embed_dim):
        super(PatchEmbedding, self).__init__()
        self.n_patches = (img_size // patch_size) ** 2
        self.patch_size = patch_size
        self.patchify = nn.Unfold(kernel_size = patch_size, stride = patch_size)
        self.mlp = nn.Linear(in_channels * patch_size ** 2, embed_dim)
    

    def forward(self, x):
        x = self.patchify(x)
        x = x.transpose(-1, -2)
        return F.relu(self.mlp(x))


class TransformerEncoderBlock(nn.Module):
    def __init__(self, embed_dim, num_heads, mlp_dim, dropout=0.1):
        super(TransformerEncoderBlock, self).__init__()
        self.norm1 = nn.LayerNorm(embed_dim)
        self.attn = nn.MultiheadAttention(embed_dim, num_heads, dropout=dropout)
        self.norm2 = nn.LayerNorm(embed_dim)
        self.mlp = nn.Sequential(
            nn.Linear(embed_dim, mlp_dim),
            nn.GELU(),
            nn.Linear(mlp_dim, embed_dim),
            nn.Dropout(dropout)
        )

    def forward(self, x):
        # Multi-head attention
        x1 = self.norm1(x)
        x2 = x + self.attn(x1, x1, x1)[0]
        # MLP
        x3 = self.norm2(x2)
        x3 = x2 + self.mlp(x3)
        return x3


class VisionTransformer(nn.Module):
    def __init__(self, img_size=224, patch_size=16, in_channels=3, num_classes=1000, embed_dim=768, num_heads=12, mlp_dim=3072, num_layers=12, dropout=0.1):
        super(VisionTransformer, self).__init__()
        self.patch_embed = PatchEmbedding(img_size, patch_size, in_channels, embed_dim)
        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        self.pos_embed = nn.Parameter(torch.zeros(1, self.patch_embed.n_patches + 1, embed_dim))
        self.dropout = nn.Dropout(dropout)

        # Transformer Encoder Blocks
        self.encoder = nn.ModuleList([
            TransformerEncoderBlock(embed_dim, num_heads, mlp_dim, dropout)
            for _ in range(num_layers)
        ])

        self.norm = nn.LayerNorm(embed_dim)
        self.head = nn.Linear(embed_dim, num_classes)

    def forward(self, x):
        x = self.patch_embed(x)
        batch_size = x.shape[0]

        # Append class token
        cls_tokens = self.cls_token.expand(batch_size, -1, -1)
        x = torch.cat((cls_tokens, x), dim=1)

        # Add positional encoding
        x = x + self.pos_embed
        x = self.dropout(x)

        # Pass through Transformer layers
        for layer in self.encoder:
            x = layer(x)

        x = self.norm(x)
        x = torch.mean(x, dim = 1)
        # print(x.shape)
        # cls_token_final = x[:, 0]

        # Classification head
        x = F.sigmoid(self.head(x))
        return x

In [6]:
def progressbar(max, iter, start_time, dash_len = 50, starter_txt = ''):
    f = starter_txt + '[' + ''.join(['-']*dash_len)+']'
    stdout.write('\r')
    f = f.replace('-', '#', int((iter+1)/max*dash_len))
    f += f'{(iter+1)/max * 100:.2f} ### {time.time()-start_time:.1f} elapsed.'
    stdout.write(f)
    stdout.flush()


@torch.no_grad()
def test(model, device, test_dataloader, criterion, threshold = .5):
    model.eval()
    dsize = len(test_dataloader)
    cm = None
    outer_loss = 0
    start_time = time.time()
    for i, (images, labels) in enumerate(test_dataloader):
        images, labels = images.to(device), labels.to(device)
        output = model(images)
        outer_loss += criterion(output, labels).cpu().item()
        predict = (output >= threshold).float()
        ccm = ConfuseMatrix(labels.cpu().numpy(), predict.cpu().numpy())
        progressbar(dsize, i, start_time, 10)
    cm += ccm
    return outer_loss, cm

In [7]:
## hyper parameters
epoch = 100
img_size = 224
patch_size = 16
embed_size = 160
att_head = 8
mlp_size = embed_size
dropout = 0
tblock = 2

sf = f'./models/ViT_{epoch}_{img_size}_{patch_size}_{embed_size}_{att_head}_{mlp_size}_{tblock}_{dropout}.pth'
sf_best = f'./models/ViT_{epoch}_{img_size}_{patch_size}_{embed_size}_{att_head}_{mlp_size}_{tblock}_{dropout}_best.pth'
assert os.path.exists(sf) == True
assert os.path.exists(sf_best) == True

In [8]:
model = VisionTransformer(img_size, patch_size, 3, 8, embed_size, att_head, mlp_size, tblock, dropout*.1)
model = model.to(device)
criterion = nn.BCELoss()
optimizer = torch.optim.Adam(model.parameters(), lr = 1e-4)

In [9]:
def load_checkpoint(filename="model_checkpoint.pth", device = 'cuda'):
    print("=> Loading checkpoint")
    checkpoint = torch.load(filename, weights_only=False)
    model.load_state_dict(checkpoint['model_state_dict'])
    model.to(device)
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    epoch = checkpoint['epoch']
    loss = checkpoint['loss']
    return epoch, loss

In [10]:
epoch, loss = load_checkpoint(sf)
print(f"epoch = {epoch}, loss = {loss:.2f}")
loss, cm = test(model, device, test_dataloader, criterion)
print(test_dataloader)

print('\naccuracy = ', cm.acc())

print(f'loss = {loss:.2f}')

print('specificity = ', cm.spec().round(2))

print('sensitivity = ', cm.sens().round(2))

print('precision = ', cm.prec().round(2))

=> Loading checkpoint
epoch = 100, loss = 11.28
[##########]100.00 ### 7.8 elapsed.<torch.utils.data.dataloader.DataLoader object at 0x000001F10C464310>

accuracy =  0.4482757074911354
loss = 2.15
specificity =  [0.29 0.82 1.   0.93 1.   1.   0.92 1.  ]
sensitivity =  [0.8  0.57 0.   0.   0.   0.   0.5  0.  ]
precision =  [0.4  0.67 0.   0.   0.   0.   0.5  0.  ]


In [11]:
epoch, loss = load_checkpoint(sf_best)
print(f"epoch = {epoch}, loss = {loss:.2f}")
loss, cm = test(model, device, test_dataloader, criterion)
print(test_dataloader)

print('\naccuracy = ', cm.acc())

print(f'loss = {loss:.2f}')

print('specificity = ', cm.spec().round(2))

print('sensitivity = ', cm.sens().round(2))

print('precision = ', cm.prec().round(2))

=> Loading checkpoint
epoch = 80, loss = 11.82
[##########]100.00 ### 7.7 elapsed.<torch.utils.data.dataloader.DataLoader object at 0x000001F10C464310>

accuracy =  0.35483859521335637
loss = 2.13
specificity =  [0.33 0.43 1.   0.92 1.   1.   0.92 1.  ]
sensitivity =  [0.46 0.71 0.   0.   0.   0.   0.   0.  ]
precision =  [0.37 0.38 0.   0.   0.   0.   0.   0.  ]
