In [None]:
import torch
from torch.utils.data import DataLoader
import torch.nn.functional as F
import sys
import os
from utils import *
from config_setting import setting_config_multitask as config
from datasets import Skin_Dataset
from model.BDFormer import BDFormer

In [None]:
train_dataset = Skin_Dataset(config, split="train")

# Loader settings come from the experiment config. pin_memory=True allows the
# `.cuda(non_blocking=True)` host-to-device copies in the analysis loop to be asynchronous.
loader_options = {
    "batch_size": config.batch_size,
    "num_workers": config.num_workers,
    "shuffle": True,
    "pin_memory": True,
}
train_loader = DataLoader(train_dataset, **loader_options)

In [None]:
# Seed first (set_seed comes from utils) so anything random during model
# construction follows config.seed.
set_seed(config.seed)
gpu_ids = [0]
torch.cuda.empty_cache()

model = torch.nn.DataParallel(
    BDFormer(img_size=256, in_channels=3, num_classes=config.num_classes, window_size=8).cuda(),
    device_ids=gpu_ids,
    output_device=gpu_ids[0],
)

# Freeze the whole network, then re-enable gradients only for the child of
# multi_task_MaxViT registered under the name 'backbone' (no-op if absent).
# Only those parameters will receive gradients afterwards.
model.requires_grad_(False)
children_by_name = dict(model.module.multi_task_MaxViT.named_children())
backbone_module = children_by_name.get('backbone')
if backbone_module is not None:
    backbone_module.requires_grad_(True)

In [None]:
# Put the run's work_dir on the import path - presumably so objects pickled into
# the checkpoint can be resolved (TODO confirm). Guarded so re-running the cell
# does not keep appending duplicate entries.
work_dir_entry = config.work_dir + '/'
if work_dir_entry not in sys.path:
    sys.path.append(work_dir_entry)

checkpoint_dir = os.path.join(config.work_dir, 'checkpoints')
checkpoint_path = os.path.join(checkpoint_dir, 'latest.pth')
# SECURITY: weights_only=False unpickles arbitrary Python objects, which can execute
# code. Only load checkpoints from a trusted source.
latest_state = torch.load(checkpoint_path, map_location=torch.device('cpu'), weights_only=False)

# strict=False tolerates key mismatches, but ignoring them silently would mean analysing
# a partly randomly-initialised model. Report any mismatch so it cannot go unnoticed.
load_report = model.module.load_state_dict(latest_state['model_state_dict'], strict=False)
if load_report.missing_keys or load_report.unexpected_keys:
    print(f'Warning: checkpoint/model key mismatch - '
          f'{len(load_report.missing_keys)} missing (e.g. {load_report.missing_keys[:3]}), '
          f'{len(load_report.unexpected_keys)} unexpected (e.g. {load_report.unexpected_keys[:3]})')

In [None]:
total_sim = 0
num_batches = 5
bce_dice_loss = BceDiceLoss()  # segmentation loss; not used in the contour-vs-class comparison below
ce_loss = CrossEntropyLoss()
# Class 1 of the contour map is weighted 55x - presumably because contour pixels are rare (TODO confirm).
ce_loss_contour = CrossEntropyLoss(weight=torch.tensor([1.0, 55.0]).cuda())
dice_loss_contour = DiceLoss_Contour(n_classes=2)

# Fixed, ordered list of the parameters left trainable by the freezing cell; both task
# gradients are laid out over exactly these tensors so the vectors are aligned.
trainable_params = [p for p in model.parameters() if p.requires_grad]


def flat_task_grad(loss, params, retain_graph):
    """Return d(loss)/d(params) flattened into one 1-D tensor.

    Parameters the loss does not reach contribute zeros (instead of being dropped),
    so gradient vectors from different tasks always have the same length and each
    element refers to the same parameter entry.
    """
    grads = torch.autograd.grad(loss, params, retain_graph=retain_graph, allow_unused=True)
    return torch.cat([
        (torch.zeros_like(p) if g is None else g).reshape(-1)
        for p, g in zip(params, grads)
    ])


model.train()

batches_seen = 0
for batch_idx, data in enumerate(train_loader):
    if batch_idx >= num_batches:
        break

    target_img, target_seg, target_contour, target_class = data
    target_img = target_img.cuda(non_blocking=True).float()
    target_seg = target_seg.cuda(non_blocking=True).float()
    target_contour = target_contour.cuda(non_blocking=True).float().squeeze(dim=1)
    target_class = target_class.cuda(non_blocking=True).long()

    # A single forward pass feeds both task losses, so both gradients are taken w.r.t.
    # the same dropout mask / batch statistics, and BatchNorm running stats update once.
    # (The segmentation head could be compared the same way via flat_task_grad.)
    _, pred_contour, pred_class = model(target_img)

    pred_contour_soft = F.softmax(pred_contour, dim=1)
    loss_contour_ce = ce_loss_contour(pred_contour, target_contour.long())
    loss_contour_dice = dice_loss_contour(pred_contour, target_contour)
    loss_contour_mse = F.mse_loss(pred_contour_soft[:, 1, :, :], target_contour)
    loss_contour = 0.4 * loss_contour_ce + 0.2 * loss_contour_dice + 0.4 * loss_contour_mse
    loss_cls = ce_loss(pred_class, target_class)

    # Keep the shared graph alive after the first gradient; the second call frees it.
    grad_contour = flat_task_grad(loss_contour, trainable_params, retain_graph=True)
    grad_cls = flat_task_grad(loss_cls, trainable_params, retain_graph=False)

    sim = F.cosine_similarity(grad_contour, grad_cls, dim=0)
    total_sim += sim.item()
    batches_seen += 1

# Average over the batches actually processed (the loader may yield fewer than num_batches).
if batches_seen == 0:
    print('No batches were processed; cannot compute gradient cosine similarity.')
else:
    print(f'Average Gradient Cosine Similarity: {total_sim / batches_seen:.4f}')