In [None]:
from Utils.Utils import *
from Utils.Blacksmith import * 

from Utils.HyMNet import HyMNet
from timm.models.layers import trunc_normal_
import Utils.ViT as vit 

In [None]:
# Use the first CUDA GPU when one is visible; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")

In [None]:
# Seed the RNGs through the project's set_seed helper (from Utils.Utils) so runs
# are repeatable; which libraries it seeds is defined there — see that module.
SEED = 0
set_seed(SEED)

In [None]:
# Data and checkpoint locations. Each can be overridden through an environment
# variable so the notebook runs on another machine; the defaults are the original
# paths. os.path.join keeps the sub-paths correct whether or not PATH ends in "/".
PATH = os.environ.get("HTN_DATA_DIR", "/home/baharoon/HTN/data/")
# Directories holding the hypertensive (HTN) and non-hypertensive (NonHTN) images.
CSV_PATH = {"HTNPath": os.path.join(PATH, "HTN"), "NonHTNPath": os.path.join(PATH, "NonHTN")}

MODELS_PATH = os.environ.get("HTN_MODELS_DIR", "/home/baharoon/HTN/Models")

os.makedirs(MODELS_PATH, exist_ok=True)

In [None]:
BATCH_SIZE = 16

# Spatial pipeline: resize to RESIZE_SIZE x RESIZE_SIZE, then center-crop to CROP_SIZE.
# For the earlier 224-px experiments set RESIZE_SIZE = 256 and CROP_SIZE = 224.
RESIZE_SIZE = 586
CROP_SIZE = 512

# ImageNet channel statistics used by T.Normalize.
IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD = [0.229, 0.224, 0.225]


def build_transform(train, resize_size=RESIZE_SIZE, crop_size=CROP_SIZE):
    """Build the image preprocessing pipeline shared by all splits.

    Args:
        train: when True, insert the augmentations (random horizontal flip,
            random rotation in [0, 360] degrees, Gaussian blur with kernel
            size 3) between ToTensor and Normalize.
        resize_size: side length images are resized to before cropping.
        crop_size: side length of the center crop fed to the model.

    Returns:
        A T.Compose that ends with ImageNet normalization.
    """
    steps = [
        T.Resize((resize_size, resize_size)),
        T.CenterCrop(crop_size),
        T.ToTensor(),
    ]
    if train:
        steps += [
            T.RandomHorizontalFlip(0.5),
            T.RandomRotation(degrees=(0, 360)),
            T.GaussianBlur(3),
        ]
    steps.append(T.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD))
    return T.Compose(steps)


train_transform = build_transform(train=True)
test_transform = build_transform(train=False)

# The val split is merged into the training set below, so it deliberately receives the
# augmenting train_transform. It is passed through `test_transform=`, which is
# presumably the transform HypertensionDataset applies to non-"train" splits
# (TODO confirm in Utils).
train_dataset = HypertensionDataset(CSV_PATH, split="train", train_transform=train_transform)
val_dataset = HypertensionDataset(CSV_PATH, split="val", test_transform=train_transform)
test_dataset = HypertensionDataset(CSV_PATH, split="test", test_transform=test_transform)

train_dataset = torch.utils.data.ConcatDataset([train_dataset, val_dataset])

# No separate validation loader: the val split now lives inside train_dataset.
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False)

# ImageModel

In [None]:
# Learning rate -> best AUROC reported by train_val; filled in by the training loop.
results = dict()

In [None]:
# Single-logit binary objective (HTN vs NonHTN); the loss applies the sigmoid itself.
criterion = nn.BCEWithLogitsLoss()

# Backbone learning rates to sweep; each one trains a fresh model.
lrs = [1e-6]

In [None]:
# Fine-tune RETFound inside HyMNet once per backbone learning rate in `lrs` and record
# the best AUROC that train_val reports for each run.
HEAD_LR = 0.005  # the classification head is trained with a much larger LR than the backbone
epochs = 50


def build_optimizer(model, backbone_lr, head_lr=HEAD_LR):
    """Create an AdamW optimizer with separate learning rates for head and backbone.

    Args:
        model: a HyMNet instance (not yet wrapped in DataParallel) exposing
            ``model.image_model`` and ``model.image_model.head``.
        backbone_lr: learning rate for every image-model parameter outside the head.
        head_lr: learning rate for the parameters of ``model.image_model.head``.

    Returns:
        torch.optim.AdamW with two parameter groups: head first, backbone second.
    """
    head_params = list(model.image_model.head.parameters())
    # Exclude the head by tensor identity instead of deleting hard-coded names such as
    # 'head.weight'/'head.bias': any head layout is handled, and no tensor can end up
    # in both groups (which AdamW rejects).
    head_ids = {id(p) for p in head_params}
    backbone_params = [p for p in model.image_model.parameters() if id(p) not in head_ids]
    return torch.optim.AdamW([
        {"params": head_params, "lr": head_lr},
        {"params": backbone_params, "lr": backbone_lr},
    ])


for lr in lrs:
    # image_size must match the crop size produced by the data transforms (512 here).
    image_model = get_retfound("/home/baharoon/HTN/RETFound_cfp_weights.pth", classes=1, image_size=512).requires_grad_(True)
    model = HyMNet(image_model=image_model)

    # Build the optimizer before DataParallel wrapping, while `model.image_model` is
    # still directly reachable.
    optimizer = build_optimizer(model, backbone_lr=lr)

    # Parallelize model to multiple GPUs
    if torch.cuda.device_count() > 1:
        print("Using", torch.cuda.device_count(), "GPUs!")
        model = nn.DataParallel(model)

    # Send the model to GPU
    model.to(device)

    # Cosine schedule length is epochs * batches-per-epoch, i.e. it assumes train_val
    # steps the scheduler once per batch (TODO confirm in train_val).
    epoch_length = len(train_loader)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, epoch_length * epochs, eta_min=0)

    # NOTE(review): the val split was merged into training, so the test loader doubles
    # as the validation loader. Taking the max test AUROC over epochs (and any best-model
    # selection inside train_val) therefore gives an optimistically biased estimate.
    metrics, bm = train_val(epochs=epochs, model=model, criterion=criterion, optimizer=optimizer, train_loader=train_loader,
                        val_loader=test_loader, scheduler=scheduler, device=device, save_model=True)
    results[lr] = max(metrics[1][i]["AUROC"] for i in metrics[1])

In [None]:
# Persist the learning-rate -> best-AUROC summary. json.dump turns the float keys into
# strings (e.g. 1e-06 -> "1e-06").
# NOTE(review): the file name says 224, but the training cell builds RETFound with
# image_size=512. Rename the file if earlier 224-px results must not be overwritten.
RESULTS_FILE = "/home/baharoon/HTN/HyMNet/Results/retfound_224.json"
# open(..., 'w') raises FileNotFoundError if Results/ does not exist yet.
os.makedirs(os.path.dirname(RESULTS_FILE), exist_ok=True)
with open(RESULTS_FILE, 'w') as f:
    json.dump(results, f)

In [None]:
# Serialize `bm` as returned by the most recent train_val call, i.e. the run for the
# last learning rate in `lrs` (presumably the best-scoring model of that run).
retfound_checkpoint = MODELS_PATH + r'/Retfound.pth'
torch.save(bm, retfound_checkpoint)