# Download OASIS DATASet

In [4]:
from huggingface_hub import list_models

# List up to 10 models with "bert" in their name
bert_models = list_models(search="retfound", limit=10)
for m in bert_models:
    print(m.modelId)


open-eye/RETFound_MAE
bitfount/RETFound_MAE
bitfount/RETFound_MAE_OCT
bitfount/RETFound_MAE_OCT_CNV_DME_DRU
Unified/RETfound_eyepacs
bswift/RETfound_eyepacs_DR
sebasmos/retfound-finetuned-lora-retfound
bitfount/RETFound_DR_IDRID
calumburnstone/RETFoundtest


In [None]:
import clip
print(clip.available_models())

  from .autonotebook import tqdm as notebook_tqdm


['RN50', 'RN101', 'RN50x4', 'RN50x16', 'RN50x64', 'ViT-B/32', 'ViT-B/16', 'ViT-L/14', 'ViT-L/14@336px']


In [None]:
import kagglehub

# Download latest version
path = kagglehub.dataset_download("ninadaithal/imagesoasis")

print("Path to dataset files:", path)

Path to dataset files: /home/lab308/.cache/kagglehub/datasets/ninadaithal/imagesoasis/versions/1


In [None]:
import torch
import os
from tqdm import tqdm
import time
import torch.nn as nn

### Load Dataset
Data_size: 224 x 224
1. Non demented: 6,7222
2. mild demented: 5002
3. moderate demented: 488
4. very demented: 1,3725

In [None]:
# Model settings
import clip
from timm import optim

device = "cuda" if torch.cuda.is_available() else "cpu"
model, preprocess = clip.load("ViT-B/16", device=device)

optimizer = optim.AdamW(model.parameters(), lr=1e-4, weight_decay=0.05)
criteria = nn.CrossEntropyLoss()

In [None]:
# import dataset
from dataset import BasicDataset
from torch.utils.data import DataLoader, random_split, ConcatDataset
from torchvision import datasets, transforms

train_dataset = datasets.ImageFolder(root='data/train', transform=preprocess)

total_size = len(train_dataset)
train_size = int(0.8*total_size)
val_size = total_size - train_size

train_set, val_set = random_split(train_dataset, [train_size, val_size])

train_loader = DataLoader(train_set, batch_size=32, shuffle=True)
val_loader = DataLoader(val_set, batch_size=32, shuffle=True)

In [None]:
def load_checkpoints(epoch, model, optimizer, stage):
    checkpoint_path = f"checkpoints/ConvNeXtV2/{stage}"
    
    if os.path.exists(checkpoint_path):
        print(f"Load checkpoint from {checkpoint_path}")
        checkpoint = torch.load(checkpoint_path)
        model.load_state_dict(checkpoint['model'])
        optimizer.load_state_dict(checkpoint['optimizer'])
        epoch = checkpoint['epoch']
        print(f"Loaded checkpoint from epoch {epoch}")

def save_checkpoints(epoch, model, optimizer, stage):
    checkpoint_path = f"checkpoints/ConvNeXtV2/{stage}"
    
    if not os.path.exists(checkpoint_path):
        os.makedirs(checkpoint_path)
    
    torch.save({
        'model': model.state_dict(),
        'optimizer': optimizer.state_dict()
    }, checkpoint_path + f"/checkpoint_{epoch}.pth")

In [None]:
# Training
import copy
finetune_epoch = 10
start_epoch = 0
best_loss = 1000
best_weights = copy.deepcopy(model.state_dict())
cathegories = ["Non Demented", "Mild Demeted", "Moderate Demented", "Very Mild Demented"]
localtime = time.asctime( time.localtime(time.time()) )
save_model_path = os.path.join("save_models", "CLIP")
os.makedirs(save_model_path, exist_ok=True)

model.train()

for epoch in range(start_epoch, finetune_epoch):
    epoch_loss = 0

    print(f"Epoch {epoch+1}/{finetune_epoch}", localtime)
    print("-" * len("Epoch {}/{}".format(epoch+1, finetune_epoch)))

    for batch in tqdm(train_loader):
        img, label = batch
        img, label = img.to(device), label.to(device)
        
        optimizer.zero_grad()

        logits_img, logits_text = model(img, clip.tokenize(cathegories).to(device))
        probs = logits_img.softmax(dim=-1).to(torch.float32)
        loss = criteria(probs, label)
        
        loss.backward()
        optimizer.step()
        
        epoch_loss += loss.item()
    print(f"Loss: {epoch_loss/len(train_loader)}")

    #Validation
    val_loss = 0
    model.eval()

    with torch.no_grad():
        for batch in tqdm(val_loader):
            img, label = batch
            img, label = img.to(device), clip.tokenize(label).to(device)
            

            logits_img, logits_text = model(img, label)
            probs = logits_img.softmax(dim=-1).detach().to(torch.float32)
            val_loss = criteria(probs, label)

            if val_loss < best_loss:
                best_loss = val_loss
                best_weights = copy.deepcopy(model.state_dict())
torch.save(best_weights, os.path.join("best_weights_{val_loss}_epoch_{epoch+1}.pth"))       
        
        

In [None]:
#validation
from sklearn.metrics import precision_score, recall_score, f1_score, accuracy_score


test_loss = 0
model.eval()
#model.load_state_dict(best_weights)
cathegories = ["Non Demented", "Mild Demeted", "Moderate Demented", "Very Mild Demented"]

# Initialize lists to store true labels and predictions
all_labels = []
all_preds = []

with torch.no_grad():
    for batch in tqdm(val_loader):
        img, label = batch
        img, label = img.to(device), label.to(device)

        #image_features = model.encode_image(img)
        #text_features = model.encode_text(clip.tokenize(cathegories).to(device))

        logits_img, logits_text = model(img, clip.tokenize(cathegories).to(device))
        #probs = logits_img.softmax(dim=-1).cpu().numpy()
        loss = criteria(logits_img, label)
        
        # Store the true labels and predictions
        preds = logits_img.argmax(dim=1).cpu().numpy()
        all_labels.extend(label.cpu().numpy())
        all_preds.extend(preds)
        
# Calculate precision, recall, F1 score, and accuracy
precision = precision_score(all_labels, all_preds, average='weighted')
recall = recall_score(all_labels, all_preds, average='weighted')
f1 = f1_score(all_labels, all_preds, average='weighted')
accuracy = accuracy_score(all_labels, all_preds)

print(f"Precision: {precision:.4f}")
print(f"Recall: {recall:.4f}")
print(f"F1 Score: {f1:.4f}")
print(f"Accuracy: {accuracy:.4f}")

print("Test loss:", loss.item())

100%|██████████| 541/541 [04:28<00:00,  2.02it/s]

Precision: 0.6094
Recall: 0.6588
F1 Score: 0.6274
Accuracy: 0.6588
Test loss: 0.90625



