# Download OASIS DATASet

In [1]:
import kagglehub

# Download latest version
path = kagglehub.dataset_download("ninadaithal/imagesoasis")

print("Path to dataset files:", path)

  from .autonotebook import tqdm as notebook_tqdm


Path to dataset files: /home/user/.cache/kagglehub/datasets/ninadaithal/imagesoasis/versions/1


In [2]:
import torch
import os
from tqdm import tqdm
import numpy as np
import cv2
import matplotlib.pyplot as plt
import time
from PIL import Image
import torch.nn as nn

### Load Dataset
Data_size: 224 x 224
1. Non demented: 6,7222
2. mild demented: 5002
3. moderate demented: 488
4. very demented: 1,3725

In [3]:
# import dataset
from torch.utils.data import DataLoader
from torchvision import datasets, transforms

transform = transforms.Compose([
    transforms.CenterCrop(size=(224, 224)),  # resize to 224x224
    transforms.ToTensor(),          # convert PIL image to PyTorch tensor
    transforms.Normalize(
        mean=[0.485, 0.456, 0.406], 
        std=[0.229, 0.224, 0.225]
    )  # standard ImageNet normalization
])

train_dataset = datasets.ImageFolder(root='data/train', transform=transform)
train_loader = DataLoader(train_dataset, batch_size=8, shuffle=True)

total_size = len(train_dataset)
train_size = int(0.8 * total_size)
val_size = total_size - train_size
train_dataset, val_dataset = torch.utils.data.random_split(train_dataset, [train_size, val_size])
finetune_loader = DataLoader(train_dataset, batch_size=8, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=8, shuffle=True)


In [4]:
# Model settings
from timm.optim import optim_factory
from Model.ViT.models_vit import VisionTransformer
from Model.ViT.models_mae import MaskedAutoencoderViT
from PIL import Image

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

weight_path = os.path.join("Model", "ViT", "mae_pretrain_vit_base.pth")
model_mae = MaskedAutoencoderViT(embed_dim=768, depth=12, num_heads=12)
check_point = torch.load(weight_path)
check_point_model = check_point['model']
model_mae.load_state_dict(check_point_model, strict=False)


param_groups = optim_factory.add_weight_decay(model_mae, 0.05)
optimizer = torch.optim.AdamW(param_groups, lr=1.5e-4, betas=(0.9, 0.95))

mseloss = nn.MSELoss()

model_mae = model_mae.to(device)

In [5]:
def load_checkpoints(epoch, model, optimizer, stage):
    checkpoint_path = f"checkpoints/RETfound/{stage}"
    
    if os.path.exists(checkpoint_path):
        print(f"Load checkpoint from {checkpoint_path}")
        checkpoint = torch.load(checkpoint_path)
        model.load_state_dict(checkpoint['model'])
        optimizer.load_state_dict(checkpoint['optimizer'])
        epoch = checkpoint['epoch']
        print(f"Loaded checkpoint from epoch {epoch}")

def save_checkpoints(epoch, model, optimizer, stage):
    checkpoint_path = f"checkpoints/ConvNeXtV2/{stage}"
    
    if not os.path.exists(checkpoint_path):
        os.makedirs(checkpoint_path)
    
    torch.save({
        'model': model.state_dict(),
        'optimizer': optimizer.state_dict()
    }, checkpoint_path + f"/checkpoint_{epoch}.pth")

# Pretrain fcmae

In [None]:
#pretrain mae on imagenet
pretrain_epoch = 2000
start_epoch = 0


if start_epoch:
    load_checkpoints(start_epoch, model_mae, optimizer, stage="imagenet")

for epoch in range(start_epoch, pretrain_epoch):
    torch.cuda.empty_cache()
    epoch_loss = 0

    localtime = time.asctime( time.localtime(time.time()) )
    tqdm.write('Epoch: {}/{} --- < Starting Time : {} >'.format(epoch + 1,pretrain_epoch,localtime))
    tqdm.write('-' * len('Epoch: {}/{} --- < Starting Time : {} >'.format(epoch + 1,pretrain_epoch, localtime)))

    folder_name = os.path.join("see_image", "imagent")
    os.makedirs(folder_name, exist_ok=True)

    torch.cuda.empty_cache()
    
    for batch in tqdm(imagenet_loader):
        img, label = batch
        img = img.to(device)


            
        model_mae.train()
        model_autoencoder.eval()
        optimizer.zero_grad()


        loss, pred, mask = model_mae(img)
            
        image = pred[0].cpu().detach().numpy().transpose(1, 2, 0)
        img_np = img[0].cpu().detach().numpy().transpose(1, 2, 0)
        img_np = (img_np * 255).astype(np.uint8)
        pred_np = (image * 255).astype(np.uint8)

        cv2.imshow("pred img", pred_np)
        cv2.waitKey(1)
        cv2.imshow("targe img", img_np)
        cv2.waitKey(1)

        epoch_loss += loss
        loss.backward()
        optimizer.step()
        
    print(f'Epoch{epoch+1} loss : \n pretrain loss : {epoch_loss}')

    if (epoch+1) % 10 == 0:
        save_checkpoints(epoch, model_mae, optimizer, stage="imagenet")

In [None]:
#pretrain mae on oasis dataset
pretrain_epoch = 2000
start_epoch = 0


if start_epoch:
    load_checkpoints(start_epoch-1, model_mae, optimizer, stage="pretrain")

for epoch in range(start_epoch, pretrain_epoch):
    torch.cuda.empty_cache()
    epoch_loss = 0

    localtime = time.asctime( time.localtime(time.time()) )
    tqdm.write('Epoch: {}/{} --- < Starting Time : {} >'.format(epoch + 1,pretrain_epoch,localtime))
    tqdm.write('-' * len('Epoch: {}/{} --- < Starting Time : {} >'.format(epoch + 1,pretrain_epoch, localtime)))

    folder_name = os.path.join("see_image", "pre-train")
    os.makedirs(folder_name, exist_ok=True)

    torch.cuda.empty_cache()
    
    for batch in tqdm(train_loader):
        img, label = batch
        img = img.to(device)


        loss, pred, mask = model_mae(img)
            
        image = pred[0].cpu().detach().numpy().transpose(1, 2, 0)
        img_np = img[0].cpu().detach().numpy().transpose(1, 2, 0)
        img_np = (img_np * 255).astype(np.uint8)
        pred_np = (image * 255).astype(np.uint8)

        cv2.imshow("pred img", pred_np)
        cv2.waitKey(1)
        cv2.imshow("targe img", img_np)
        cv2.waitKey(1)
            
        epoch_loss += loss
        loss.backward()
        optimizer.step()
        
    print(f'Epoch{epoch+1} loss : \n pretrain loss : {epoch_loss}')

    if (epoch+1) % 10 == 0:
        save_checkpoints(epoch+1, model=model_mae, optimizer=optimizer, stage="pretrain")

In [6]:
from transformers import ViTForImageClassification

model_autoencoder = ViTForImageClassification.from_pretrained('facebook/deit-base-patch16-224')
model_autoencoder.classifier = nn.Linear(model_autoencoder.config.hidden_size, 4)
model_autoencoder = model_autoencoder.to(device)

param_groups = optim_factory.add_weight_decay(model_autoencoder, 0.05)
optimizer = torch.optim.AdamW(param_groups, lr=1.5e-4, betas=(0.9, 0.95))

In [7]:
# finetune classification
from sklearn.metrics import precision_score, recall_score, f1_score, accuracy_score
model_mae = None
torch.cuda.empty_cache()
pretrain_epoch = 25
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
start_epoch = 0

ce_loss = nn.CrossEntropyLoss()

if start_epoch:
    load_checkpoints(start_epoch-1, model_mae, optimizer, stage="finetune")

for epoch in range(start_epoch, pretrain_epoch):
    torch.cuda.empty_cache()
    epoch_loss = 0

    localtime = time.asctime( time.localtime(time.time()) )
    print('Epoch: {}/{} --- < Starting Time : {} >'.format(epoch + 1, pretrain_epoch,localtime))
    print('-' * len('Epoch: {}/{} --- < Starting Time : {} >'.format(epoch + 1, pretrain_epoch,localtime)))

    folder_name = os.path.join("see_image", "warm")
    os.makedirs(folder_name, exist_ok=True)

    torch.cuda.empty_cache()

    #model_mae.eval()
    model_autoencoder.train()
    
    
    for batch in tqdm(finetune_loader):
        img, label = batch
        img, label = img.to(device), label.to(device)


        outputs = model_autoencoder(img)
        loss = ce_loss(outputs.logits, label)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        epoch_loss += loss
        
        
    print(f'Epoch{epoch+1} loss : \n pretrain loss : {epoch_loss}')
    epoch_loss = epoch_loss / len(finetune_loader)
    
    with open("Deit_train.txt", 'a+') as f:
        f.write(f"{epoch_loss}\n")
        
    #validation



    test_loss = 0
    model_autoencoder.eval()

    # Initialize lists to store true labels and predictions
    all_labels = []
    all_preds = []

    with torch.no_grad():
        for batch in tqdm(val_loader):
            img, label = batch
            img, label = img.to(device), label.to(device)

            #image_features = model.encode_image(img)
            #text_features = model.encode_text(clip.tokenize(cathegories).to(device))

            outputs = model_autoencoder(img)
            #probs = logits_img.softmax(dim=-1).cpu().numpy()
            loss = ce_loss(outputs.logits, label)

            
            # Store the true labels and predictions
            preds = outputs.logits.argmax(dim=1).cpu().numpy()
            all_labels.extend(label.cpu().numpy())
            all_preds.extend(preds)
            test_loss += loss.item()
    test_loss = test_loss / len(val_loader)
    with open("Deit_val.txt",  "a+") as f:
        f.write(f"{test_loss}\n")
            
    # Calculate precision, recall, F1 score, and accuracy
    precision = precision_score(all_labels, all_preds, average='weighted')
    recall = recall_score(all_labels, all_preds, average='weighted')
    f1 = f1_score(all_labels, all_preds, average='weighted')
    accuracy = accuracy_score(all_labels, all_preds)

    print(f"Precision: {precision:.4f}")
    print(f"Recall: {recall:.4f}")
    print(f"F1 Score: {f1:.4f}")
    print(f"Accuracy: {accuracy:.4f}")

    print("Test loss:", test_loss)

Epoch: 1/25 --- < Starting Time : Sat Dec 21 17:01:58 2024 >
------------------------------------------------------------


100%|██████████| 8644/8644 [09:34<00:00, 15.04it/s]


Epoch1 loss : 
 pretrain loss : 2341.689697265625


100%|██████████| 2161/2161 [01:04<00:00, 33.27it/s]


Precision: 0.9436
Recall: 0.9452
F1 Score: 0.9425
Accuracy: 0.9452
Test loss: 0.17036653988232275
Epoch: 2/25 --- < Starting Time : Sat Dec 21 17:12:40 2024 >
------------------------------------------------------------


100%|██████████| 8644/8644 [09:26<00:00, 15.25it/s]


Epoch2 loss : 
 pretrain loss : 915.5570678710938


100%|██████████| 2161/2161 [01:03<00:00, 33.87it/s]


Precision: 0.9749
Recall: 0.9747
F1 Score: 0.9741
Accuracy: 0.9747
Test loss: 0.07963207234829545
Epoch: 3/25 --- < Starting Time : Sat Dec 21 17:23:12 2024 >
------------------------------------------------------------


100%|██████████| 8644/8644 [09:46<00:00, 14.75it/s]


Epoch3 loss : 
 pretrain loss : 612.7422485351562


100%|██████████| 2161/2161 [01:03<00:00, 34.25it/s]


Precision: 0.9630
Recall: 0.9542
F1 Score: 0.9562
Accuracy: 0.9542
Test loss: 0.1261129685612835
Epoch: 4/25 --- < Starting Time : Sat Dec 21 17:34:04 2024 >
------------------------------------------------------------


100%|██████████| 8644/8644 [09:45<00:00, 14.76it/s]


Epoch4 loss : 
 pretrain loss : 550.233642578125


100%|██████████| 2161/2161 [01:04<00:00, 33.58it/s]


Precision: 0.9738
Recall: 0.9736
F1 Score: 0.9731
Accuracy: 0.9736
Test loss: 0.0965578322934732
Epoch: 5/25 --- < Starting Time : Sat Dec 21 17:44:56 2024 >
------------------------------------------------------------


100%|██████████| 8644/8644 [09:46<00:00, 14.74it/s]


Epoch5 loss : 
 pretrain loss : 512.3602905273438


100%|██████████| 2161/2161 [01:03<00:00, 34.01it/s]


Precision: 0.9829
Recall: 0.9827
F1 Score: 0.9825
Accuracy: 0.9827
Test loss: 0.0533166986725211
Epoch: 6/25 --- < Starting Time : Sat Dec 21 17:55:49 2024 >
------------------------------------------------------------


100%|██████████| 8644/8644 [09:44<00:00, 14.79it/s]


Epoch6 loss : 
 pretrain loss : 501.3293762207031


100%|██████████| 2161/2161 [01:03<00:00, 34.11it/s]


Precision: 0.9836
Recall: 0.9835
F1 Score: 0.9832
Accuracy: 0.9835
Test loss: 0.05299413266077475
Epoch: 7/25 --- < Starting Time : Sat Dec 21 18:06:39 2024 >
------------------------------------------------------------


100%|██████████| 8644/8644 [09:50<00:00, 14.64it/s]


Epoch7 loss : 
 pretrain loss : 464.352783203125


100%|██████████| 2161/2161 [01:03<00:00, 33.77it/s]


Precision: 0.9895
Recall: 0.9895
F1 Score: 0.9893
Accuracy: 0.9895
Test loss: 0.043076391866755576
Epoch: 8/25 --- < Starting Time : Sat Dec 21 18:17:35 2024 >
------------------------------------------------------------


100%|██████████| 8644/8644 [09:46<00:00, 14.75it/s]


Epoch8 loss : 
 pretrain loss : 450.60369873046875


100%|██████████| 2161/2161 [01:03<00:00, 34.05it/s]


Precision: 0.9906
Recall: 0.9903
F1 Score: 0.9904
Accuracy: 0.9903
Test loss: 0.029329715198570694
Epoch: 9/25 --- < Starting Time : Sat Dec 21 18:28:27 2024 >
------------------------------------------------------------


100%|██████████| 8644/8644 [09:49<00:00, 14.67it/s]


Epoch9 loss : 
 pretrain loss : 410.21026611328125


100%|██████████| 2161/2161 [01:03<00:00, 33.91it/s]


Precision: 0.9842
Recall: 0.9833
F1 Score: 0.9836
Accuracy: 0.9833
Test loss: 0.05482148844568941
Epoch: 10/25 --- < Starting Time : Sat Dec 21 18:39:22 2024 >
-------------------------------------------------------------


100%|██████████| 8644/8644 [09:49<00:00, 14.66it/s]


Epoch10 loss : 
 pretrain loss : 366.26910400390625


100%|██████████| 2161/2161 [01:03<00:00, 34.06it/s]


Precision: 0.9917
Recall: 0.9914
F1 Score: 0.9914
Accuracy: 0.9914
Test loss: 0.034307790474295725
Epoch: 11/25 --- < Starting Time : Sat Dec 21 18:50:18 2024 >
-------------------------------------------------------------


100%|██████████| 8644/8644 [09:47<00:00, 14.72it/s]


Epoch11 loss : 
 pretrain loss : 368.1532897949219


100%|██████████| 2161/2161 [01:03<00:00, 33.92it/s]


Precision: 0.9928
Recall: 0.9927
F1 Score: 0.9928
Accuracy: 0.9927
Test loss: 0.023094751533894634
Epoch: 12/25 --- < Starting Time : Sat Dec 21 19:01:11 2024 >
-------------------------------------------------------------


100%|██████████| 8644/8644 [09:48<00:00, 14.69it/s]


Epoch12 loss : 
 pretrain loss : 368.3670959472656


100%|██████████| 2161/2161 [01:03<00:00, 33.84it/s]


Precision: 0.9847
Recall: 0.9846
F1 Score: 0.9846
Accuracy: 0.9846
Test loss: 0.04835356811044966
Epoch: 13/25 --- < Starting Time : Sat Dec 21 19:12:05 2024 >
-------------------------------------------------------------


100%|██████████| 8644/8644 [09:45<00:00, 14.76it/s]


Epoch13 loss : 
 pretrain loss : 361.8525695800781


100%|██████████| 2161/2161 [01:03<00:00, 34.02it/s]


Precision: 0.9854
Recall: 0.9854
F1 Score: 0.9853
Accuracy: 0.9854
Test loss: 0.04521650358014648
Epoch: 14/25 --- < Starting Time : Sat Dec 21 19:22:57 2024 >
-------------------------------------------------------------


100%|██████████| 8644/8644 [09:49<00:00, 14.66it/s]


Epoch14 loss : 
 pretrain loss : 365.41961669921875


100%|██████████| 2161/2161 [01:03<00:00, 33.89it/s]


Precision: 0.9914
Recall: 0.9913
F1 Score: 0.9913
Accuracy: 0.9913
Test loss: 0.03371794105842258
Epoch: 15/25 --- < Starting Time : Sat Dec 21 19:33:53 2024 >
-------------------------------------------------------------


100%|██████████| 8644/8644 [09:48<00:00, 14.70it/s]


Epoch15 loss : 
 pretrain loss : 372.59490966796875


100%|██████████| 2161/2161 [01:03<00:00, 34.26it/s]


Precision: 0.9859
Recall: 0.9859
F1 Score: 0.9849
Accuracy: 0.9859
Test loss: 0.05915397962172943
Epoch: 16/25 --- < Starting Time : Sat Dec 21 19:44:46 2024 >
-------------------------------------------------------------


100%|██████████| 8644/8644 [09:48<00:00, 14.70it/s]


Epoch16 loss : 
 pretrain loss : 335.3423156738281


100%|██████████| 2161/2161 [01:03<00:00, 34.11it/s]


Precision: 0.9852
Recall: 0.9839
F1 Score: 0.9838
Accuracy: 0.9839
Test loss: 0.06089903158134997
Epoch: 17/25 --- < Starting Time : Sat Dec 21 19:55:39 2024 >
-------------------------------------------------------------


100%|██████████| 8644/8644 [09:45<00:00, 14.75it/s]


Epoch17 loss : 
 pretrain loss : 340.43804931640625


100%|██████████| 2161/2161 [01:03<00:00, 33.88it/s]


Precision: 0.9901
Recall: 0.9899
F1 Score: 0.9900
Accuracy: 0.9899
Test loss: 0.03132240874184255
Epoch: 18/25 --- < Starting Time : Sat Dec 21 20:06:31 2024 >
-------------------------------------------------------------


100%|██████████| 8644/8644 [09:48<00:00, 14.69it/s]


Epoch18 loss : 
 pretrain loss : 341.41021728515625


100%|██████████| 2161/2161 [01:03<00:00, 34.22it/s]


Precision: 0.9905
Recall: 0.9905
F1 Score: 0.9905
Accuracy: 0.9905
Test loss: 0.032844705094591226
Epoch: 19/25 --- < Starting Time : Sat Dec 21 20:17:25 2024 >
-------------------------------------------------------------


100%|██████████| 8644/8644 [09:46<00:00, 14.74it/s]


Epoch19 loss : 
 pretrain loss : 313.7786865234375


100%|██████████| 2161/2161 [01:03<00:00, 34.03it/s]


Precision: 0.9938
Recall: 0.9938
F1 Score: 0.9937
Accuracy: 0.9938
Test loss: 0.027584747023547913
Epoch: 20/25 --- < Starting Time : Sat Dec 21 20:28:17 2024 >
-------------------------------------------------------------


100%|██████████| 8644/8644 [09:47<00:00, 14.72it/s]


Epoch20 loss : 
 pretrain loss : 312.65582275390625


100%|██████████| 2161/2161 [01:03<00:00, 33.80it/s]


Precision: 0.9869
Recall: 0.9868
F1 Score: 0.9868
Accuracy: 0.9868
Test loss: 0.04821373227776825
Epoch: 21/25 --- < Starting Time : Sat Dec 21 20:39:11 2024 >
-------------------------------------------------------------


100%|██████████| 8644/8644 [09:47<00:00, 14.72it/s]


Epoch21 loss : 
 pretrain loss : 314.2381286621094


100%|██████████| 2161/2161 [01:03<00:00, 33.87it/s]


Precision: 0.9908
Recall: 0.9906
F1 Score: 0.9906
Accuracy: 0.9906
Test loss: 0.03249695347190434
Epoch: 22/25 --- < Starting Time : Sat Dec 21 20:50:04 2024 >
-------------------------------------------------------------


100%|██████████| 8644/8644 [09:46<00:00, 14.73it/s]


Epoch22 loss : 
 pretrain loss : 334.6990966796875


100%|██████████| 2161/2161 [01:03<00:00, 33.97it/s]


Precision: 0.9868
Recall: 0.9865
F1 Score: 0.9866
Accuracy: 0.9865
Test loss: 0.0512643877431382
Epoch: 23/25 --- < Starting Time : Sat Dec 21 21:00:57 2024 >
-------------------------------------------------------------


100%|██████████| 8644/8644 [09:47<00:00, 14.72it/s]


Epoch23 loss : 
 pretrain loss : 329.301025390625


100%|██████████| 2161/2161 [01:03<00:00, 34.14it/s]


Precision: 0.9893
Recall: 0.9883
F1 Score: 0.9886
Accuracy: 0.9883
Test loss: 0.04163101706803671
Epoch: 24/25 --- < Starting Time : Sat Dec 21 21:11:50 2024 >
-------------------------------------------------------------


100%|██████████| 8644/8644 [09:48<00:00, 14.69it/s]


Epoch24 loss : 
 pretrain loss : 310.4465026855469


100%|██████████| 2161/2161 [01:03<00:00, 33.86it/s]


Precision: 0.9872
Recall: 0.9868
F1 Score: 0.9869
Accuracy: 0.9868
Test loss: 0.04626175354122405
Epoch: 25/25 --- < Starting Time : Sat Dec 21 21:22:44 2024 >
-------------------------------------------------------------


100%|██████████| 8644/8644 [09:45<00:00, 14.77it/s]


Epoch25 loss : 
 pretrain loss : 340.0849609375


100%|██████████| 2161/2161 [01:03<00:00, 34.19it/s]

Precision: 0.9929
Recall: 0.9928
F1 Score: 0.9928
Accuracy: 0.9928
Test loss: 0.022847359690412378



