# Download the OASIS Dataset

In [10]:
import kagglehub

# Fetch the latest version of the OASIS MRI image dataset from Kaggle; the
# returned value is the local directory the files were placed in (see the
# output below for where it lands on this machine).
# NOTE(review): the data loaders further down read from 'data/train', not from
# `path` — confirm how the downloaded files end up in that folder.
path = kagglehub.dataset_download("ninadaithal/imagesoasis")
print(f"Path to dataset files: {path}")

Path to dataset files: /home/lab308/.cache/kagglehub/datasets/ninadaithal/imagesoasis/versions/1


In [11]:
import torch
import os
from tqdm import tqdm
import numpy as np
import cv2
import matplotlib.pyplot as plt
import time
from PIL import Image
import torch.nn as nn

### Load Dataset
Image size: 224 x 224 (center crop)
1. Non demented: 67,222
2. Mild demented: 5,002
3. Moderate demented: 488
4. Very mild demented: 13,725

In [12]:
# Build the data pipelines: the OASIS MRI images (one sub-folder per class,
# read by ImageFolder) and ImageNet for generic MAE pretraining.
from dataset import BasicDataset  # NOTE(review): not used in this cell
from torch.utils.data import DataLoader
from torchvision import datasets, transforms

BATCH_SIZE = 512

transform = transforms.Compose([
    # Crops the central 224x224 region — it does NOT resize, so images larger
    # than 224 px lose their borders. NOTE(review): confirm a crop (rather than
    # transforms.Resize) is what is intended for these MRI slices.
    transforms.CenterCrop(size=(224, 224)),
    transforms.ToTensor(),          # PIL image -> float tensor in [0, 1], shape (C, H, W)
    # Standard ImageNet statistics. Any code that displays these tensors must
    # undo exactly this normalisation before scaling to 0..255.
    transforms.Normalize(
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225],
    ),
])

train_dataset = datasets.ImageFolder(root='data/train', transform=transform)
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)

# Keep the Dataset and its DataLoader under distinct names instead of binding
# the Dataset to `imagenet_loader` and then overwriting it.
imagenet_dataset = datasets.ImageNet(root='data/imagenet', split='train', transform=transform)
imagenet_loader = DataLoader(imagenet_dataset, batch_size=BATCH_SIZE, shuffle=True)

In [13]:
# Model settings
from timm.optim import optim_factory
from Model.ViT.models_vit import VisionTransformer
from Model.ViT.models_mae import MaskedAutoencoderViT
from PIL import Image

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# ViT-Base masked autoencoder initialised from the pretrained MAE checkpoint.
weight_path = os.path.join("Model", "ViT", "mae_pretrain_vit_base.pth")
model_mae = MaskedAutoencoderViT(embed_dim=768, depth=12, num_heads=12)
# map_location='cpu' lets the checkpoint load regardless of the device it was
# saved from; load_state_dict then copies the tensors into the model.
check_point = torch.load(weight_path, map_location='cpu')
check_point_model = check_point['model']
# strict=False is needed because the checkpoint's key listing (printed in a
# cell below) has no mask_token / decoder_* entries. It also has no
# blocks.*.attn.qkv.bias keys although the model defines them, so those keep
# their initial values. Print the load result so every missing/unexpected key
# is visible instead of being silently ignored.
load_msg = model_mae.load_state_dict(check_point_model, strict=False)
print(load_msg)

# ViT classifier with a 4-way head (one output per dementia class).
# NOTE: a later cell rebinds this name to a Hugging Face DeiT classifier.
model_autoencoder = VisionTransformer(embed_dim=768, depth=12, num_heads=12, num_classes=4)

# Move the models to the target device *before* building the optimizer so the
# optimizer is constructed over the final, device-resident parameters.
model_mae = model_mae.to(device)
model_autoencoder = model_autoencoder.to(device)

# AdamW over timm's weight-decay / no-weight-decay parameter groups of the MAE.
param_groups = optim_factory.add_weight_decay(model_mae, 0.05)
optimizer = torch.optim.AdamW(param_groups, lr=1.5e-4, betas=(0.9, 0.95))

mseloss = nn.MSELoss()

In [14]:
# Sanity check after loading: one "<parameter name>: <mean value>" line for
# every parameter tensor of the MAE model.
for param_name, param_tensor in model_mae.named_parameters():
    mean_value = param_tensor.mean().item()
    print(f"{param_name}: {mean_value}")

cls_token: -0.0016049178084358573
pos_embed: 0.4599889814853668
mask_token: -0.0013113151071593165
decoder_pos_embed: 0.4594338536262512
patch_embed.proj.weight: 0.00012506816710811108
patch_embed.proj.bias: 0.013790666125714779
blocks.0.norm1.weight: 0.2868507504463196
blocks.0.norm1.bias: -0.007356846239417791
blocks.0.attn.qkv.weight: -8.371970034204423e-05
blocks.0.attn.qkv.bias: -0.020099656656384468
blocks.0.attn.proj.weight: 7.620392716489732e-05
blocks.0.attn.proj.bias: -0.01881934516131878
blocks.0.norm2.weight: 0.8042056560516357
blocks.0.norm2.bias: -0.020366661250591278
blocks.0.mlp.fc1.weight: -0.000608887814451009
blocks.0.mlp.fc1.bias: -1.0230910778045654
blocks.0.mlp.fc2.weight: 4.994144183001481e-05
blocks.0.mlp.fc2.bias: 0.033312924206256866
blocks.1.norm1.weight: 0.5066466331481934
blocks.1.norm1.bias: 0.0013204384595155716
blocks.1.attn.qkv.weight: 1.9829030861728825e-05
blocks.1.attn.qkv.bias: 0.02302500605583191
blocks.1.attn.proj.weight: 6.507000307465205e-06
blo

In [15]:
# List the state-dict keys stored in the pretrained checkpoint so they can be
# compared with the model's parameter names.
for key_name in check_point_model:
    print(key_name)

cls_token
pos_embed
patch_embed.proj.weight
patch_embed.proj.bias
blocks.0.norm1.weight
blocks.0.norm1.bias
blocks.0.attn.qkv.weight
blocks.0.attn.proj.weight
blocks.0.attn.proj.bias
blocks.0.norm2.weight
blocks.0.norm2.bias
blocks.0.mlp.fc1.weight
blocks.0.mlp.fc1.bias
blocks.0.mlp.fc2.weight
blocks.0.mlp.fc2.bias
blocks.1.norm1.weight
blocks.1.norm1.bias
blocks.1.attn.qkv.weight
blocks.1.attn.proj.weight
blocks.1.attn.proj.bias
blocks.1.norm2.weight
blocks.1.norm2.bias
blocks.1.mlp.fc1.weight
blocks.1.mlp.fc1.bias
blocks.1.mlp.fc2.weight
blocks.1.mlp.fc2.bias
blocks.2.norm1.weight
blocks.2.norm1.bias
blocks.2.attn.qkv.weight
blocks.2.attn.proj.weight
blocks.2.attn.proj.bias
blocks.2.norm2.weight
blocks.2.norm2.bias
blocks.2.mlp.fc1.weight
blocks.2.mlp.fc1.bias
blocks.2.mlp.fc2.weight
blocks.2.mlp.fc2.bias
blocks.3.norm1.weight
blocks.3.norm1.bias
blocks.3.attn.qkv.weight
blocks.3.attn.proj.weight
blocks.3.attn.proj.bias
blocks.3.norm2.weight
blocks.3.norm2.bias
blocks.3.mlp.fc1.weigh

In [16]:
# Single root shared by save and load so resuming finds what was saved.
CHECKPOINT_ROOT = os.path.join("checkpoints", "ConvNeXtV2")


def load_checkpoints(epoch, model, optimizer, stage, root=CHECKPOINT_ROOT):
    """Restore model/optimizer state from ``<root>/<stage>/checkpoint_<epoch>.pth``.

    Reads exactly the file layout written by ``save_checkpoints``.

    Args:
        epoch: Checkpoint number to load (the ``epoch`` value that was passed to
            ``save_checkpoints`` when the file was written).
        model: Module whose state is restored in place.
        optimizer: Optimizer whose state is restored in place.
        stage: Sub-folder name, e.g. ``"pretrain"`` or ``"finetune"``.
        root: Base checkpoint folder (defaults to the one ``save_checkpoints`` uses).

    Returns:
        The epoch stored in the checkpoint (``epoch`` itself for files that did
        not record it), or ``None`` when no such checkpoint file exists.
    """
    checkpoint_path = os.path.join(root, stage, f"checkpoint_{epoch}.pth")
    if not os.path.isfile(checkpoint_path):
        print(f"No checkpoint found at {checkpoint_path}")
        return None

    print(f"Load checkpoint from {checkpoint_path}")
    # Load on CPU; load_state_dict copies values onto the existing parameters' devices.
    checkpoint = torch.load(checkpoint_path, map_location="cpu")
    model.load_state_dict(checkpoint['model'])
    optimizer.load_state_dict(checkpoint['optimizer'])
    # Older files were saved without 'epoch'; fall back to the requested number.
    loaded_epoch = checkpoint.get('epoch', epoch)
    print(f"Loaded checkpoint from epoch {loaded_epoch}")
    return loaded_epoch


def save_checkpoints(epoch, model, optimizer, stage, root=CHECKPOINT_ROOT):
    """Save model/optimizer state to ``<root>/<stage>/checkpoint_<epoch>.pth``.

    Args:
        epoch: Number used in the file name; also stored under ``'epoch'``.
        model: Module whose ``state_dict`` is saved.
        optimizer: Optimizer whose ``state_dict`` is saved.
        stage: Sub-folder name, e.g. ``"pretrain"`` or ``"finetune"``.
        root: Base checkpoint folder.

    Returns:
        Path of the written checkpoint file.
    """
    stage_dir = os.path.join(root, stage)
    os.makedirs(stage_dir, exist_ok=True)
    checkpoint_path = os.path.join(stage_dir, f"checkpoint_{epoch}.pth")
    torch.save({
        'epoch': epoch,
        'model': model.state_dict(),
        'optimizer': optimizer.state_dict(),
    }, checkpoint_path)
    return checkpoint_path

# Pretrain the MAE (ViT masked autoencoder)

In [None]:
# Pretrain the MAE on ImageNet.
pretrain_epoch = 2000
start_epoch = 0  # set to N to resume from checkpoint_N.pth (written after N completed epochs)

# Must match the Normalize() of the data transform; used only to turn tensors
# back into viewable images.
vis_mean = torch.tensor([0.485, 0.456, 0.406]).view(3, 1, 1)
vis_std = torch.tensor([0.229, 0.224, 0.225]).view(3, 1, 1)

if start_epoch:
    # Checkpoints below are saved with epoch + 1 (= completed epochs), so
    # resuming at start_epoch loads that same number.
    load_checkpoints(start_epoch, model_mae, optimizer, stage="imagenet")

folder_name = os.path.join("see_image", "imagenet")
os.makedirs(folder_name, exist_ok=True)

model_mae.train()

for epoch in range(start_epoch, pretrain_epoch):
    torch.cuda.empty_cache()
    epoch_loss = 0.0
    num_batches = 0

    localtime = time.asctime(time.localtime(time.time()))
    header = 'Epoch: {}/{} --- < Starting Time : {} >'.format(epoch + 1, pretrain_epoch, localtime)
    tqdm.write(header)
    tqdm.write('-' * len(header))

    for img, _ in tqdm(imagenet_loader):
        img = img.to(device)

        optimizer.zero_grad()
        loss, pred, mask = model_mae(img)
        loss.backward()
        optimizer.step()

        # .item() stores a plain float; accumulating the tensor itself would
        # keep autograd nodes alive for the whole epoch.
        epoch_loss += loss.item()
        num_batches += 1

    if num_batches:
        # Save one target|reconstruction pair per epoch (last batch) to disk,
        # which also works on a headless kernel, unlike cv2.imshow.
        # `pred` is per-patch (N, L, patch_size**2 * 3), so unpatchify it back
        # to (N, 3, H, W) before display.
        # NOTE(review): assumes models_mae follows the official MAE code
        # (unpatchify available, norm_pix_loss disabled).
        with torch.no_grad():
            recon = model_mae.unpatchify(pred[:1].detach())[0].cpu()
        pair = torch.cat([img[0].detach().cpu(), recon], dim=2)
        pair = (pair * vis_std + vis_mean).clamp(0, 1)  # undo Normalize
        pair_rgb = (pair.permute(1, 2, 0) * 255).to(torch.uint8).contiguous().numpy()
        cv2.imwrite(os.path.join(folder_name, f"epoch_{epoch + 1}.png"),
                    cv2.cvtColor(pair_rgb, cv2.COLOR_RGB2BGR))

    mean_loss = epoch_loss / max(num_batches, 1)
    print(f'Epoch{epoch+1} loss : \n pretrain loss (mean over batches) : {mean_loss:.4f}')

    if (epoch + 1) % 10 == 0:
        save_checkpoints(epoch + 1, model=model_mae, optimizer=optimizer, stage="imagenet")

In [None]:
# Pretrain the MAE on the OASIS dataset.
pretrain_epoch = 2000
start_epoch = 0  # set to N to resume from checkpoint_N.pth (written after N completed epochs)

# Must match the Normalize() of the data transform; used only to turn tensors
# back into viewable images.
vis_mean = torch.tensor([0.485, 0.456, 0.406]).view(3, 1, 1)
vis_std = torch.tensor([0.229, 0.224, 0.225]).view(3, 1, 1)

if start_epoch:
    # Checkpoints below are saved with epoch + 1 (= completed epochs), so
    # resuming at start_epoch loads that same number.
    load_checkpoints(start_epoch, model_mae, optimizer, stage="pretrain")

folder_name = os.path.join("see_image", "pre-train")
os.makedirs(folder_name, exist_ok=True)

# Another cell may have left the model in eval mode.
model_mae.train()

for epoch in range(start_epoch, pretrain_epoch):
    torch.cuda.empty_cache()
    epoch_loss = 0.0
    num_batches = 0

    localtime = time.asctime(time.localtime(time.time()))
    header = 'Epoch: {}/{} --- < Starting Time : {} >'.format(epoch + 1, pretrain_epoch, localtime)
    tqdm.write(header)
    tqdm.write('-' * len(header))

    for img, _ in tqdm(train_loader):
        img = img.to(device)

        # Clear gradients every step; without this they accumulate across batches.
        optimizer.zero_grad()
        loss, pred, mask = model_mae(img)
        loss.backward()
        optimizer.step()

        # .item() stores a plain float instead of a tensor with autograd history.
        epoch_loss += loss.item()
        num_batches += 1

    if num_batches:
        # Save one target|reconstruction pair per epoch (last batch) to disk.
        # `pred` is per-patch (N, L, patch_size**2 * 3), so unpatchify it back
        # to (N, 3, H, W) before display.
        # NOTE(review): assumes models_mae follows the official MAE code
        # (unpatchify available, norm_pix_loss disabled).
        with torch.no_grad():
            recon = model_mae.unpatchify(pred[:1].detach())[0].cpu()
        pair = torch.cat([img[0].detach().cpu(), recon], dim=2)
        pair = (pair * vis_std + vis_mean).clamp(0, 1)  # undo Normalize
        pair_rgb = (pair.permute(1, 2, 0) * 255).to(torch.uint8).contiguous().numpy()
        cv2.imwrite(os.path.join(folder_name, f"epoch_{epoch + 1}.png"),
                    cv2.cvtColor(pair_rgb, cv2.COLOR_RGB2BGR))

    mean_loss = epoch_loss / max(num_batches, 1)
    print(f'Epoch{epoch+1} loss : \n pretrain loss (mean over batches) : {mean_loss:.4f}')

    if (epoch + 1) % 10 == 0:
        save_checkpoints(epoch + 1, model=model_mae, optimizer=optimizer, stage="pretrain")

In [21]:
from transformers import ViTForImageClassification

# Replace the classifier with a pretrained DeiT-Base from Hugging Face and give
# it a fresh 4-way head (one output per OASIS class folder).
# Passing num_labels keeps model.config.num_labels consistent with the head;
# replacing `classifier` by hand left the config at the checkpoint's original
# label count, which breaks the model's built-in loss when labels are passed.
# ignore_mismatched_sizes lets the checkpoint's original head be discarded and
# re-initialised.
model_autoencoder = ViTForImageClassification.from_pretrained(
    'facebook/deit-base-patch16-224',
    num_labels=4,
    ignore_mismatched_sizes=True,
)
model_autoencoder = model_autoencoder.to(device)


In [None]:
# Fine-tune the classifier on the OASIS dataset.
# NOTE(review): `model_autoencoder` is the Hugging Face DeiT built in the
# previous cell, so the MAE weights pretrained above are NOT used here; copy the
# MAE encoder weights into the classifier first if transfer is intended.
finetune_epoch = 2000
start_epoch = 0  # set to N to resume from checkpoint_N.pth (written after N completed epochs)
ce_loss = nn.CrossEntropyLoss()

# The classifier needs its own optimizer: `optimizer` from the model-settings
# cell only holds model_mae's parameters, so stepping it cannot train this model.
# TODO: tune lr / weight decay.
finetune_optimizer = torch.optim.AdamW(model_autoencoder.parameters(), lr=1e-4, weight_decay=0.05)

if start_epoch:
    load_checkpoints(start_epoch, model_autoencoder, finetune_optimizer, stage="finetune")

model_autoencoder.train()

for epoch in range(start_epoch, finetune_epoch):
    torch.cuda.empty_cache()
    epoch_loss = 0.0
    num_batches = 0
    num_correct = 0
    num_seen = 0

    localtime = time.asctime(time.localtime(time.time()))
    header = 'Epoch: {}/{} --- < Starting Time : {} >'.format(epoch + 1, finetune_epoch, localtime)
    tqdm.write(header)
    tqdm.write('-' * len(header))

    for img, label in tqdm(train_loader):
        img, label = img.to(device), label.to(device)

        # Forward pass must record gradients (no torch.no_grad()), and the
        # update has to happen once per batch, inside this loop.
        finetune_optimizer.zero_grad()
        logits = model_autoencoder(img).logits
        loss = ce_loss(logits, label)
        loss.backward()
        finetune_optimizer.step()

        epoch_loss += loss.item()
        num_batches += 1
        num_correct += (logits.argmax(dim=1) == label).sum().item()
        num_seen += label.size(0)

    mean_loss = epoch_loss / max(num_batches, 1)
    train_acc = num_correct / max(num_seen, 1)
    print(f'Epoch{epoch+1} loss : \n finetune loss : {mean_loss:.4f} | train acc : {train_acc:.4f}')

    if (epoch + 1) % 100 == 0:
        save_checkpoints(epoch + 1, model=model_autoencoder, optimizer=finetune_optimizer, stage="finetune")

Epoch: 1/2000 --- < Starting Time : Thu Dec 19 16:33:52 2024 >
--------------------------------------------------------------


  0%|          | 0/169 [00:00<?, ?it/s]

  2%|▏         | 4/169 [00:09<07:05,  2.58s/it]