# Download OASIS DATASet

In [1]:
import kagglehub

# Download latest version
path = kagglehub.dataset_download("ninadaithal/imagesoasis")

print("Path to dataset files:", path)

Path to dataset files: /home/lab308/.cache/kagglehub/datasets/ninadaithal/imagesoasis/versions/1


In [2]:
import torch
import os
from tqdm import tqdm
import numpy as np
import cv2
import matplotlib.pyplot as plt
import time
from PIL import Image
import torch.nn as nn

### Load Dataset
Data_size: 224 x 224
1. Non demented: 6,7222
2. mild demented: 5002
3. moderate demented: 488
4. very demented: 1,3725

In [3]:
# import dataset
from dataset import BasicDataset
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
from transformers import AutoImageProcessor

transform = transforms.Compose([
    transforms.CenterCrop(size=(224, 224)),  # resize to 224x224
    transforms.ToTensor(),          # convert PIL image to PyTorch tensor
    transforms.Normalize(
        mean=[0.485, 0.456, 0.406], 
        std=[0.229, 0.224, 0.225]
    )  # standard ImageNet normalization
])

preprocessor = AutoImageProcessor.from_pretrained("facebook/convnextv2-base-22k-224")

# import dataset
from dataset import BasicDataset
from torch.utils.data import DataLoader
from torchvision import datasets, transforms

transform = transforms.Compose([
    transforms.CenterCrop(size=(224, 224)),  # resize to 224x224
    transforms.ToTensor(),          # convert PIL image to PyTorch tensor
    transforms.Normalize(
        mean=[0.485, 0.456, 0.406], 
        std=[0.229, 0.224, 0.225]
    )  # standard ImageNet normalization
])

train_dataset = datasets.ImageFolder(root='data/train', transform=transform)
train_loader = DataLoader(train_dataset, batch_size=8, shuffle=True)

total_size = len(train_dataset)
train_size = int(0.8 * total_size)
val_size = total_size - train_size
train_dataset, val_dataset = torch.utils.data.random_split(train_dataset, [train_size, val_size])
finetune_loader = DataLoader(train_dataset, batch_size=8, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=8, shuffle=True)

imagenet_loader = datasets.ImageNet(root='data/imagenet', split='train', transform=transform)
imagenet_loader = DataLoader(imagenet_loader, batch_size=8, shuffle=True)

In [4]:
images, labels = next(iter(train_loader))
print("Image batch shape:", images.shape)
print("Label batch shape:", labels.shape)
print("Labels:", labels)


Image batch shape: torch.Size([8, 3, 224, 224])
Label batch shape: torch.Size([8])
Labels: tensor([2, 2, 2, 2, 2, 2, 2, 2])


In [5]:
# Model settings
from Model.ConvNeXtV2.convnextv2 import convnextv2_base as autoencoder
from Model.ConvNeXtV2.fcmae import convnextv2_base as mae
from timm.optim import optim_factory
from timm.models.layers import trunc_normal_

weight_path = os.path.join("Model", "ConvNeXtV2", "pretrained", "convnextv2_base_22k_224_ema.pt")

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

model_mae = mae()

checkpoint = torch.load(weight_path, map_location='cpu')

print("Load pre-trained checkpoint from: %s" % weight_path)
checkpoint_model = checkpoint['model']
for k in ['head.weight', 'head.bias']:
    if k in checkpoint_model :
        print(f"Removing key {k} from pretrained checkpoint")
        del checkpoint_model[k]
        
#model_mae.load_state_dict(checkpoint_model, strict=False)


model_autoencoder = autoencoder(num_classes=4)
model_autoencoder.load_state_dict(checkpoint_model, strict=False)
# manually initialize fc layer
trunc_normal_(model_autoencoder.head.weight, std=2e-5)
torch.nn.init.constant_(model_autoencoder.head.bias, 0.)

param_groups = optim_factory.add_weight_decay(model_mae, 0.05)
optimizer = torch.optim.AdamW(param_groups, lr=1.5e-4, betas=(0.9, 0.95))

model_mae = model_mae.to(device)
model_autoencoder = model_autoencoder.to(device)

Load pre-trained checkpoint from: Model/ConvNeXtV2/pretrained/convnextv2_base_22k_224_ema.pt
Removing key head.weight from pretrained checkpoint
Removing key head.bias from pretrained checkpoint


In [6]:
for name, param in model_mae.named_parameters():
    print(f"{name}: {param.mean().item()}")

mask_token: -0.0004080507787875831
encoder.downsample_layers.0.0.weight: -0.015571651980280876
encoder.downsample_layers.0.0.bias: 0.0
encoder.downsample_layers.0.1.weight: 1.0
encoder.downsample_layers.0.1.bias: 0.0
encoder.downsample_layers.1.0.ln.weight: 1.0
encoder.downsample_layers.1.0.ln.bias: 0.0
encoder.downsample_layers.1.1.kernel: -5.830774534842931e-05
encoder.downsample_layers.1.1.bias: 0.0
encoder.downsample_layers.2.0.ln.weight: 1.0
encoder.downsample_layers.2.0.ln.bias: 0.0
encoder.downsample_layers.2.1.kernel: -6.119896625023102e-06
encoder.downsample_layers.2.1.bias: 0.0
encoder.downsample_layers.3.0.ln.weight: 1.0
encoder.downsample_layers.3.0.ln.bias: 0.0
encoder.downsample_layers.3.1.kernel: -2.589037194411503e-06
encoder.downsample_layers.3.1.bias: 0.0
encoder.stages.0.0.dwconv.kernel: -0.01506991870701313
encoder.stages.0.0.dwconv.bias: 0.0
encoder.stages.0.0.norm.ln.weight: 1.0
encoder.stages.0.0.norm.ln.bias: 0.0
encoder.stages.0.0.pwconv1.linear.weight: -0.0040

In [7]:
for key in checkpoint_model.keys():
    print(key)

downsample_layers.0.0.bias
downsample_layers.0.0.weight
downsample_layers.0.1.bias
downsample_layers.0.1.weight
downsample_layers.1.0.bias
downsample_layers.1.0.weight
downsample_layers.1.1.bias
downsample_layers.1.1.weight
downsample_layers.2.0.bias
downsample_layers.2.0.weight
downsample_layers.2.1.bias
downsample_layers.2.1.weight
downsample_layers.3.0.bias
downsample_layers.3.0.weight
downsample_layers.3.1.bias
downsample_layers.3.1.weight
norm.bias
norm.weight
stages.0.0.grn.beta
stages.0.0.grn.gamma
stages.0.0.dwconv.bias
stages.0.0.dwconv.weight
stages.0.0.norm.bias
stages.0.0.norm.weight
stages.0.0.pwconv1.bias
stages.0.0.pwconv1.weight
stages.0.0.pwconv2.bias
stages.0.0.pwconv2.weight
stages.0.1.grn.beta
stages.0.1.grn.gamma
stages.0.1.dwconv.bias
stages.0.1.dwconv.weight
stages.0.1.norm.bias
stages.0.1.norm.weight
stages.0.1.pwconv1.bias
stages.0.1.pwconv1.weight
stages.0.1.pwconv2.bias
stages.0.1.pwconv2.weight
stages.0.2.grn.beta
stages.0.2.grn.gamma
stages.0.2.dwconv.bias


In [8]:
def load_checkpoints(epoch, model, optimizer, stage):
    checkpoint_path = f"checkpoints/ConvNeXtV2/{stage}"
    
    if os.path.exists(checkpoint_path):
        print(f"Load checkpoint from {checkpoint_path}")
        checkpoint = torch.load(os.path.join(checkpoint_path, f"checkpoint_{epoch}.pth"))
        model.load_state_dict(checkpoint['model'])
        optimizer.load_state_dict(checkpoint['optimizer'])
        print(f"Loaded checkpoint from epoch {epoch}")

def save_checkpoints(epoch, model, optimizer, stage):
    checkpoint_path = f"checkpoints/ConvNeXtV2/{stage}"
    
    if not os.path.exists(checkpoint_path):
        os.makedirs(checkpoint_path)
    
    torch.save({
        'model': model.state_dict(),
        'optimizer': optimizer.state_dict()
    }, checkpoint_path + f"/checkpoint_{epoch}.pth")

In [None]:
# pretrain Imagenet1k
import torch.cuda


pretrain_epoch = 1600
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
start_epoch = 0

mse_loss = nn.MSELoss()

if start_epoch:
    load_checkpoints(start_epoch, model_mae, optimizer, stage="imagenet")

for epoch in range(start_epoch, pretrain_epoch):
    torch.cuda.empty_cache()

    epoch_loss = 0

    localtime = time.asctime( time.localtime(time.time()) )
    print('Epoch: {}/{} --- < Starting Time : {} >'.format(epoch + 1, pretrain_epoch,localtime))
    print('-' * len('Epoch: {}/{} --- < Starting Time : {} >'.format(epoch + 1, pretrain_epoch,localtime)))

    folder_name = os.path.join("see_image", "warm")
    os.makedirs(folder_name, exist_ok=True)

    torch.cuda.empty_cache()
    
    for batch in tqdm(imagenet_loader):
        torch.cuda.empty_cache()
        img, label = batch
        img, label = img.to(device), label.to(device)


        loss, pred, mask = model_mae(img)
        loss = mse_loss(pred, img)    

        image = pred[0].cpu().detach().numpy().transpose(1, 2, 0)
        img_np = img[0].cpu().detach().numpy().transpose(1, 2, 0)
        img_np = (img_np * 255).astype(np.uint8)
        pred_np = (image * 255).astype(np.uint8)

        cv2.imshow("pred img", pred_np)
        cv2.waitKey(1)
        cv2.imshow("targe img", img_np)
        cv2.waitKey(1)

        
        epoch_loss += loss
        loss.backward()
        optimizer.step()
        
    print(f'Epoch{epoch+1} loss : \n pretrain loss : {epoch_loss/len(train_loader)}')

    if (epoch+1) % 10 == 0:
        save_checkpoints(epoch+1, model=model_mae, optimizer=optimizer, stage="imagenet")
cv2.destroyAllWindows()

# Pretrain fcmae

In [None]:
# finetune mae on oasis dataset
import torch.cuda


pretrain_epoch = 1600
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
start_epoch = 0

mse_loss = nn.MSELoss()

if start_epoch:
    load_checkpoints(start_epoch, model_mae, optimizer, stage="pretrain")

for epoch in range(start_epoch, pretrain_epoch):
    torch.cuda.empty_cache()

    epoch_loss = 0

    localtime = time.asctime( time.localtime(time.time()) )
    print('Epoch: {}/{} --- < Starting Time : {} >'.format(epoch + 1, pretrain_epoch,localtime))
    print('-' * len('Epoch: {}/{} --- < Starting Time : {} >'.format(epoch + 1, pretrain_epoch,localtime)))

    folder_name = os.path.join("see_image", "warm")
    os.makedirs(folder_name, exist_ok=True)

    torch.cuda.empty_cache()
    
    for batch in tqdm(train_loader):
        torch.cuda.empty_cache()
        img, label = batch
        img, label = img.to(device), label.to(device)


        loss, pred, mask = model_mae(img)
        loss = mse_loss(pred, img)    

        image = pred[0].cpu().detach().numpy().transpose(1, 2, 0)
        img_np = img[0].cpu().detach().numpy().transpose(1, 2, 0)
        img_np = (img_np * 255).astype(np.uint8)
        pred_np = (image * 255).astype(np.uint8)

        cv2.imshow("pred img", pred_np)
        cv2.waitKey(1)
        cv2.imshow("targe img", img_np)
        cv2.waitKey(1)

        
        epoch_loss += loss
        loss.backward()
        optimizer.step()
        
    print(f'Epoch{epoch+1} loss : \n pretrain loss : {epoch_loss/len(train_loader)}')

    if (epoch+1) % 10 == 0:
        save_checkpoints(epoch+1, model=model_mae, optimizer=optimizer, stage="pretrain")
cv2.destroyAllWindows()

In [9]:
model_mae = None
torch.cuda.empty_cache()
param_groups = optim_factory.add_weight_decay(model_autoencoder, 0.05)
optimizer = torch.optim.AdamW(param_groups, lr=1.5e-4, betas=(0.9, 0.95))

In [10]:
#Load warm model weight to semi-supervised model
from Model.ConvNeXtV2.utils_param import remap_checkpoint_keys

def load_model(source_model, target_model):
    
    checkpoint_model = source_model.state_dict()

    state_dict = target_model.state_dict()
    for k in ['head.weight', 'head.bias']:
        if k in checkpoint_model and checkpoint_model[k].shape != state_dict[k].shape:
            print(f"Removing key {k} from pretrained checkpoint")
            del checkpoint_model[k]
        

    # remove decoder weights
    checkpoint_model_keys = list(checkpoint_model.keys())
    for k in checkpoint_model_keys:
        if 'decoder' in k or 'mask_token'in k or \
           'proj' in k or 'pred' in k:
            print(f"Removing key {k} from pretrained checkpoint")
            del checkpoint_model[k]

    checkpoint_model = remap_checkpoint_keys(checkpoint_model)
    target_model.load_state_dict(checkpoint_model)

    return target_model

#model_autoencoder = load_model(model_mae, model_autoencoder)

In [11]:
# finetune classification
from sklearn.metrics import precision_score, recall_score, f1_score, accuracy_score
model_mae = None
torch.cuda.empty_cache()
pretrain_epoch = 25
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
start_epoch = 0

ce_loss = nn.CrossEntropyLoss()

if start_epoch:
    load_checkpoints(start_epoch-1, model_mae, optimizer, stage="finetune")

for epoch in range(start_epoch, pretrain_epoch):
    torch.cuda.empty_cache()
    epoch_loss = 0

    localtime = time.asctime( time.localtime(time.time()) )
    print('Epoch: {}/{} --- < Starting Time : {} >'.format(epoch + 1, pretrain_epoch,localtime))
    print('-' * len('Epoch: {}/{} --- < Starting Time : {} >'.format(epoch + 1, pretrain_epoch,localtime)))

    folder_name = os.path.join("see_image", "warm")
    os.makedirs(folder_name, exist_ok=True)

    torch.cuda.empty_cache()

    #model_mae.eval()
    model_autoencoder.train()
    
    
    for batch in tqdm(finetune_loader):
        img, label = batch
        img, label = img.to(device), label.to(device)


        logits = model_autoencoder(img)
        loss = ce_loss(logits, label)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        epoch_loss += loss
    epoch_loss = epoch_loss / len(finetune_loader)
        
    print(f'Epoch{epoch+1} loss : \n pretrain loss : {epoch_loss}')
    
    with open("Convnextv2_train.txt", 'a+') as f:
        f.write(f"{epoch_loss}\n")
        
    #validation
    test_loss = 0
    model_autoencoder.eval()

    # Initialize lists to store true labels and predictions
    all_labels = []
    all_preds = []

    with torch.no_grad():
        for batch in tqdm(val_loader):
            img, label = batch
            img, label = img.to(device), label.to(device)

            #image_features = model.encode_image(img)
            #text_features = model.encode_text(clip.tokenize(cathegories).to(device))

            logits = model_autoencoder(img)
            #probs = logits_img.softmax(dim=-1).cpu().numpy()
            loss = ce_loss(logits, label)

            
            # Store the true labels and predictions
            preds = logits.argmax(dim=1).cpu().numpy()
            all_labels.extend(label.cpu().numpy())
            all_preds.extend(preds)
            test_loss += loss.item()
    test_loss = test_loss / len(val_loader)
    with open("Convnextv2_val.txt",  "a+") as f:
        f.write(f"{test_loss}\n")
            
    # Calculate precision, recall, F1 score, and accuracy
    precision = precision_score(all_labels, all_preds, average='weighted')
    recall = recall_score(all_labels, all_preds, average='weighted')
    f1 = f1_score(all_labels, all_preds, average='weighted')
    accuracy = accuracy_score(all_labels, all_preds)

    print(f"Precision: {precision:.4f}")
    print(f"Recall: {recall:.4f}")
    print(f"F1 Score: {f1:.4f}")
    print(f"Accuracy: {accuracy:.4f}")

    print("Test loss:", test_loss)

Epoch: 1/25 --- < Starting Time : Fri Dec 20 17:02:21 2024 >
------------------------------------------------------------


100%|██████████| 8644/8644 [37:52<00:00,  3.80it/s]


Epoch1 loss : 
 pretrain loss : 0.5624855756759644


100%|██████████| 2161/2161 [03:04<00:00, 11.73it/s]
  _warn_prf(average, modifier, msg_start, len(result))


Precision: 0.7190
Recall: 0.7833
F1 Score: 0.7488
Accuracy: 0.7833
Test loss: 0.5205876421194553
Epoch: 2/25 --- < Starting Time : Fri Dec 20 17:43:21 2024 >
------------------------------------------------------------


100%|██████████| 8644/8644 [38:30<00:00,  3.74it/s]


Epoch2 loss : 
 pretrain loss : 0.4728028178215027


100%|██████████| 2161/2161 [03:03<00:00, 11.76it/s]
  _warn_prf(average, modifier, msg_start, len(result))


Precision: 0.7620
Recall: 0.7933
F1 Score: 0.7758
Accuracy: 0.7933
Test loss: 0.4512654781216141
Epoch: 3/25 --- < Starting Time : Fri Dec 20 18:24:59 2024 >
------------------------------------------------------------


100%|██████████| 8644/8644 [38:28<00:00,  3.74it/s]


Epoch3 loss : 
 pretrain loss : 0.42120638489723206


100%|██████████| 2161/2161 [03:05<00:00, 11.65it/s]


Precision: 0.8148
Recall: 0.8294
F1 Score: 0.8161
Accuracy: 0.8294
Test loss: 0.38422484600117923
Epoch: 4/25 --- < Starting Time : Fri Dec 20 19:06:37 2024 >
------------------------------------------------------------


100%|██████████| 8644/8644 [38:28<00:00,  3.74it/s]


Epoch4 loss : 
 pretrain loss : 0.30743837356567383


100%|██████████| 2161/2161 [03:04<00:00, 11.72it/s]


Precision: 0.9092
Recall: 0.9108
F1 Score: 0.9045
Accuracy: 0.9108
Test loss: 0.24211680069402947
Epoch: 5/25 --- < Starting Time : Fri Dec 20 19:48:14 2024 >
------------------------------------------------------------


100%|██████████| 8644/8644 [38:27<00:00,  3.75it/s]


Epoch5 loss : 
 pretrain loss : 0.16052958369255066


100%|██████████| 2161/2161 [03:04<00:00, 11.74it/s]


Precision: 0.9537
Recall: 0.9531
F1 Score: 0.9524
Accuracy: 0.9531
Test loss: 0.12467066031977247
Epoch: 6/25 --- < Starting Time : Fri Dec 20 20:29:49 2024 >
------------------------------------------------------------


100%|██████████| 8644/8644 [38:24<00:00,  3.75it/s]


Epoch6 loss : 
 pretrain loss : 0.08539946377277374


100%|██████████| 2161/2161 [03:03<00:00, 11.80it/s]


Precision: 0.9804
Recall: 0.9803
F1 Score: 0.9803
Accuracy: 0.9803
Test loss: 0.05434288508276807
Epoch: 7/25 --- < Starting Time : Fri Dec 20 21:11:19 2024 >
------------------------------------------------------------


100%|██████████| 8644/8644 [38:20<00:00,  3.76it/s]


Epoch7 loss : 
 pretrain loss : 0.0574377179145813


100%|██████████| 2161/2161 [03:05<00:00, 11.63it/s]


Precision: 0.9780
Recall: 0.9781
F1 Score: 0.9777
Accuracy: 0.9781
Test loss: 0.06759797322807068
Epoch: 8/25 --- < Starting Time : Fri Dec 20 21:52:49 2024 >
------------------------------------------------------------


100%|██████████| 8644/8644 [38:23<00:00,  3.75it/s]


Epoch8 loss : 
 pretrain loss : 0.04155855253338814


100%|██████████| 2161/2161 [03:03<00:00, 11.78it/s]


Precision: 0.9872
Recall: 0.9872
F1 Score: 0.9872
Accuracy: 0.9872
Test loss: 0.038824917398795865
Epoch: 9/25 --- < Starting Time : Fri Dec 20 22:34:20 2024 >
------------------------------------------------------------


100%|██████████| 8644/8644 [38:26<00:00,  3.75it/s]


Epoch9 loss : 
 pretrain loss : 0.03407704830169678


100%|██████████| 2161/2161 [03:04<00:00, 11.73it/s]


Precision: 0.9735
Recall: 0.9703
F1 Score: 0.9710
Accuracy: 0.9703
Test loss: 0.08971495766016924
Epoch: 10/25 --- < Starting Time : Fri Dec 20 23:15:55 2024 >
-------------------------------------------------------------


100%|██████████| 8644/8644 [38:27<00:00,  3.75it/s]


Epoch10 loss : 
 pretrain loss : 0.02813187800347805


100%|██████████| 2161/2161 [03:05<00:00, 11.66it/s]


Precision: 0.9843
Recall: 0.9839
F1 Score: 0.9838
Accuracy: 0.9839
Test loss: 0.054336440542615085
Epoch: 11/25 --- < Starting Time : Fri Dec 20 23:57:31 2024 >
-------------------------------------------------------------


100%|██████████| 8644/8644 [38:27<00:00,  3.75it/s]


Epoch11 loss : 
 pretrain loss : 0.026227395981550217


100%|██████████| 2161/2161 [03:04<00:00, 11.73it/s]


Precision: 0.9881
Recall: 0.9877
F1 Score: 0.9878
Accuracy: 0.9877
Test loss: 0.040719108414325515
Epoch: 12/25 --- < Starting Time : Sat Dec 21 00:39:06 2024 >
-------------------------------------------------------------


100%|██████████| 8644/8644 [38:30<00:00,  3.74it/s]


Epoch12 loss : 
 pretrain loss : 0.023013759404420853


100%|██████████| 2161/2161 [03:06<00:00, 11.60it/s]


Precision: 0.9891
Recall: 0.9866
F1 Score: 0.9874
Accuracy: 0.9866
Test loss: 0.04375158126526491
Epoch: 13/25 --- < Starting Time : Sat Dec 21 01:20:46 2024 >
-------------------------------------------------------------


100%|██████████| 8644/8644 [38:29<00:00,  3.74it/s]


Epoch13 loss : 
 pretrain loss : 0.02165011130273342


100%|██████████| 2161/2161 [03:03<00:00, 11.78it/s]


Precision: 0.9936
Recall: 0.9935
F1 Score: 0.9935
Accuracy: 0.9935
Test loss: 0.020821082152135007
Epoch: 14/25 --- < Starting Time : Sat Dec 21 02:02:23 2024 >
-------------------------------------------------------------


100%|██████████| 8644/8644 [38:31<00:00,  3.74it/s]


Epoch14 loss : 
 pretrain loss : 0.02007344737648964


100%|██████████| 2161/2161 [03:05<00:00, 11.66it/s]


Precision: 0.9928
Recall: 0.9928
F1 Score: 0.9928
Accuracy: 0.9928
Test loss: 0.02632679514144273
Epoch: 15/25 --- < Starting Time : Sat Dec 21 02:44:04 2024 >
-------------------------------------------------------------


100%|██████████| 8644/8644 [38:31<00:00,  3.74it/s]


Epoch15 loss : 
 pretrain loss : 0.019127612933516502


100%|██████████| 2161/2161 [03:04<00:00, 11.73it/s]


Precision: 0.9909
Recall: 0.9909
F1 Score: 0.9909
Accuracy: 0.9909
Test loss: 0.029798990421152593
Epoch: 16/25 --- < Starting Time : Sat Dec 21 03:25:43 2024 >
-------------------------------------------------------------


100%|██████████| 8644/8644 [38:30<00:00,  3.74it/s]


Epoch16 loss : 
 pretrain loss : 0.017667638137936592


100%|██████████| 2161/2161 [03:06<00:00, 11.59it/s]


Precision: 0.9941
Recall: 0.9940
F1 Score: 0.9940
Accuracy: 0.9940
Test loss: 0.015920898362790543
Epoch: 17/25 --- < Starting Time : Sat Dec 21 04:07:24 2024 >
-------------------------------------------------------------


100%|██████████| 8644/8644 [38:33<00:00,  3.74it/s]


Epoch17 loss : 
 pretrain loss : 0.0171979870647192


100%|██████████| 2161/2161 [03:03<00:00, 11.77it/s]


Precision: 0.9940
Recall: 0.9939
F1 Score: 0.9939
Accuracy: 0.9939
Test loss: 0.02166975298532725
Epoch: 18/25 --- < Starting Time : Sat Dec 21 04:49:05 2024 >
-------------------------------------------------------------


100%|██████████| 8644/8644 [38:31<00:00,  3.74it/s]


Epoch18 loss : 
 pretrain loss : 0.0161428302526474


100%|██████████| 2161/2161 [03:05<00:00, 11.62it/s]


Precision: 0.9962
Recall: 0.9962
F1 Score: 0.9962
Accuracy: 0.9962
Test loss: 0.011463783015853797
Epoch: 19/25 --- < Starting Time : Sat Dec 21 05:30:46 2024 >
-------------------------------------------------------------


100%|██████████| 8644/8644 [38:30<00:00,  3.74it/s]


Epoch19 loss : 
 pretrain loss : 0.01737247034907341


100%|██████████| 2161/2161 [03:04<00:00, 11.74it/s]


Precision: 0.9957
Recall: 0.9957
F1 Score: 0.9957
Accuracy: 0.9957
Test loss: 0.013433803570734308
Epoch: 20/25 --- < Starting Time : Sat Dec 21 06:12:24 2024 >
-------------------------------------------------------------


100%|██████████| 8644/8644 [38:27<00:00,  3.75it/s]


Epoch20 loss : 
 pretrain loss : 0.01596021093428135


100%|██████████| 2161/2161 [03:05<00:00, 11.66it/s]


Precision: 0.9914
Recall: 0.9913
F1 Score: 0.9913
Accuracy: 0.9913
Test loss: 0.03154834559909567
Epoch: 21/25 --- < Starting Time : Sat Dec 21 06:54:00 2024 >
-------------------------------------------------------------


100%|██████████| 8644/8644 [38:32<00:00,  3.74it/s]


Epoch21 loss : 
 pretrain loss : 0.016267351806163788


100%|██████████| 2161/2161 [03:03<00:00, 11.78it/s]


Precision: 0.9948
Recall: 0.9947
F1 Score: 0.9947
Accuracy: 0.9947
Test loss: 0.015653679306171016
Epoch: 22/25 --- < Starting Time : Sat Dec 21 07:35:40 2024 >
-------------------------------------------------------------


100%|██████████| 8644/8644 [38:31<00:00,  3.74it/s]


Epoch22 loss : 
 pretrain loss : 0.014237558469176292


100%|██████████| 2161/2161 [03:05<00:00, 11.64it/s]


Precision: 0.9968
Recall: 0.9968
F1 Score: 0.9968
Accuracy: 0.9968
Test loss: 0.011285787459334572
Epoch: 23/25 --- < Starting Time : Sat Dec 21 08:17:20 2024 >
-------------------------------------------------------------


100%|██████████| 8644/8644 [38:28<00:00,  3.74it/s]


Epoch23 loss : 
 pretrain loss : 0.014600737020373344


100%|██████████| 2161/2161 [03:04<00:00, 11.74it/s]


Precision: 0.9969
Recall: 0.9969
F1 Score: 0.9969
Accuracy: 0.9969
Test loss: 0.00856858497913011
Epoch: 24/25 --- < Starting Time : Sat Dec 21 08:58:57 2024 >
-------------------------------------------------------------


100%|██████████| 8644/8644 [38:31<00:00,  3.74it/s]


Epoch24 loss : 
 pretrain loss : 0.014716844074428082


100%|██████████| 2161/2161 [03:05<00:00, 11.63it/s]


Precision: 0.9943
Recall: 0.9943
F1 Score: 0.9943
Accuracy: 0.9943
Test loss: 0.023546327718000432
Epoch: 25/25 --- < Starting Time : Sat Dec 21 09:40:37 2024 >
-------------------------------------------------------------


100%|██████████| 8644/8644 [38:26<00:00,  3.75it/s]


Epoch25 loss : 
 pretrain loss : 0.013754679821431637


100%|██████████| 2161/2161 [03:04<00:00, 11.69it/s]

Precision: 0.9978
Recall: 0.9978
F1 Score: 0.9978
Accuracy: 0.9978
Test loss: 0.0062979711272184765





In [12]:
#validation
from sklearn.metrics import precision_score, recall_score, f1_score, accuracy_score


test_loss = 0
model_autoencoder.eval()

# Initialize lists to store true labels and predictions
all_labels = []
all_preds = []

with torch.no_grad():
    for batch in tqdm(val_loader):
        img, label = batch
        img, label = img.to(device), label.to(device)

        #image_features = model.encode_image(img)
        #text_features = model.encode_text(clip.tokenize(cathegories).to(device))

        logits = model_autoencoder(img)
        #probs = logits_img.softmax(dim=-1).cpu().numpy()
        loss = ce_loss(logits, label)

        test_loss += loss
        # Store the true labels and predictions
        preds = logits.argmax(dim=1).cpu().numpy()
        all_labels.extend(label.cpu().numpy())
        all_preds.extend(preds)
        
# Calculate precision, recall, F1 score, and accuracy
precision = precision_score(all_labels, all_preds, average='weighted')
recall = recall_score(all_labels, all_preds, average='weighted')
f1 = f1_score(all_labels, all_preds, average='weighted')
accuracy = accuracy_score(all_labels, all_preds)

print(f"Precision: {precision:.4f}")
print(f"Recall: {recall:.4f}")
print(f"F1 Score: {f1:.4f}")
print(f"Accuracy: {accuracy:.4f}")

print("Test loss:", loss.item())

100%|██████████| 2161/2161 [03:03<00:00, 11.79it/s]

Precision: 0.9978
Recall: 0.9978
F1 Score: 0.9978
Accuracy: 0.9978
Test loss: 1.664435512793716e-05



