In [23]:
import os
import yaml
import argparse
import numpy as np

from glob import glob
from tqdm import tqdm
from TT_SFUDA_2D.dataset import Dataset
from TT_SFUDA_2D.metrics import iou_score
from collections import OrderedDict

from albumentations import RandomRotate90,Resize
from albumentations.augmentations import transforms
from albumentations.core.composition import Compose

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.jit
import torch.optim as optim
import torch.backends.cudnn as cudnn
import torchvision.transforms as st_transforms

from TT_SFUDA_2D.losses import *
from TT_SFUDA_2D.archs import *
from TT_SFUDA_2D.utils import *
# from TT_SFUDA_2D import archs, losses

cudnn.benchmark = True

device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")

# ============= NEW ACADEMIC IMPROVEMENTS =============

class MultiScaleConsistency:
    def __init__(self, scales=[0.75, 1.0, 1.25]):
        self.scales = scales
    
    def generate_multiscale_pseudo_labels(self, model, image):
        """Generate pseudo-labels at multiple scales for consistency"""
        pseudo_labels = []
        confidences = []
        
        h, w = image.shape[-2:]
        
        for scale in self.scales:
            new_h, new_w = int(h * scale), int(w * scale)
            if scale != 1.0:
                scaled_img = F.interpolate(image, size=(new_h, new_w), mode='bilinear', align_corners=False)
            else:
                scaled_img = image
            
            with torch.no_grad():
                pred = model(scaled_img)
                if scale != 1.0:
                    pred = F.interpolate(pred, size=(h, w), mode='bilinear', align_corners=False)
                
                prob = torch.sigmoid(pred)
                pseudo_labels.append(prob)
                
                # Confidence based on entropy
                entropy = -(prob * torch.log(prob + 1e-8) + (1-prob) * torch.log(1-prob + 1e-8))
                confidence = 1.0 - entropy / np.log(2)  # Normalize entropy
                confidences.append(confidence)
        
        # Weighted ensemble based on confidence
        weights = torch.stack(confidences)
        weights = F.softmax(weights, dim=0)
        
        final_pseudo = sum(w * pl for w, pl in zip(weights, pseudo_labels))
        final_confidence = weights.mean(dim=0)
        
        return final_pseudo, final_confidence

class MCDropoutUncertainty:
    def __init__(self, n_samples=5):  # Reduced for efficiency
        self.n_samples = n_samples
    
    def enable_dropout_inference(self, model):
        """Enable dropout during inference for uncertainty estimation"""
        for module in model.modules():
            if isinstance(module, nn.Dropout):
                module.train()
    
    def estimate_uncertainty(self, model, input_data):
        """Estimate epistemic uncertainty using MC-Dropout"""
        model.eval()
        self.enable_dropout_inference(model)
        
        predictions = []
        with torch.no_grad():
            for _ in range(self.n_samples):
                pred = torch.sigmoid(model(input_data))
                predictions.append(pred)
        
        predictions = torch.stack(predictions)
        mean_pred = predictions.mean(dim=0)
        uncertainty = predictions.var(dim=0)
        
        return mean_pred, uncertainty

class ProgressiveDomainAdaptation:
    def __init__(self, initial_threshold=0.9, final_threshold=0.6, total_epochs=100):
        self.initial_threshold = initial_threshold
        self.final_threshold = final_threshold
        self.total_epochs = total_epochs
    
    def get_confidence_threshold(self, epoch):
        """Gradually decrease confidence threshold"""
        progress = min(epoch / self.total_epochs, 1.0)
        threshold = self.initial_threshold - (self.initial_threshold - self.final_threshold) * progress
        return max(threshold, self.final_threshold)

def gaussian_kernel(x, y, gamma=1.0):
    """Gaussian RBF kernel for MMD"""
    x_size = x.size(0)
    y_size = y.size(0)
    dim = x.size(1)
    
    x = x.unsqueeze(1).expand(x_size, y_size, dim)
    y = y.unsqueeze(0).expand(x_size, y_size, dim)
    
    return torch.exp(-gamma * torch.pow((x - y), 2).sum(2))

def mmd_loss(source_features, target_features, gamma=1.0):
    """Maximum Mean Discrepancy loss for domain alignment"""
    xx = gaussian_kernel(source_features, source_features, gamma)
    yy = gaussian_kernel(target_features, target_features, gamma)
    xy = gaussian_kernel(source_features, target_features, gamma)
    
    return xx.mean() + yy.mean() - 2 * xy.mean()

def feature_alignment_loss(source_features, target_features, feature_layers=[1, 2, 3]):
    """Multi-layer feature alignment using MMD"""
    total_loss = 0
    for i in feature_layers:
        if i < len(source_features) and i < len(target_features):
            src_feat = source_features[i].view(source_features[i].size(0), -1)
            tgt_feat = target_features[i].view(target_features[i].size(0), -1)
            total_loss += mmd_loss(src_feat, tgt_feat)
    
    return total_loss / len(feature_layers) if feature_layers else 0

def confidence_weighted_loss(student_pred, teacher_pred, confidence, gamma=2.0):
    """Confidence-weighted loss with adaptive weighting"""
    # Normalize confidence to [0, 1]
    conf_norm = (confidence - confidence.min()) / (confidence.max() - confidence.min() + 1e-8)
    weight = conf_norm ** gamma
    
    mse_loss = (student_pred - teacher_pred) ** 2
    weighted_loss = weight * mse_loss
    return weighted_loss.mean()

# ============= ENHANCED EXISTING FUNCTIONS =============

def parse_args():
    parser = argparse.ArgumentParser()
    # parser.add_argument('--source', default=None, help='model name')
    # parser.add_argument('--target', default=None, help='model name')
    # # Add new hyperparameters
    # parser.add_argument('--multiscale', action='store_true', help='Use multiscale consistency')
    # parser.add_argument('--mc_dropout', action='store_true', help='Use MC dropout uncertainty')
    # parser.add_argument('--progressive', action='store_true', help='Use progressive domain adaptation')
    # parser.add_argument('--feature_align', action='store_true', help='Use feature alignment')
    # args = parser.parse_args()
    args = {}
    args["source"] = 'chase'
    args["target"] = 'rite'
    args["mc_dropout"] = True
    args["multiscale"] = False
    args["progressive"] = False
    args["feature_align"] = False
    return args

def enhanced_build_pseduo_augmentation(img):
    """Enhanced augmentation with more diversity"""
    aug1 = st_transforms.ColorJitter(0.02, 0.02, 0.02, 0.01)
    aug2 = st_transforms.RandomGrayscale(p=1.0)
    aug3 = st_transforms.RandomSolarize(threshold=128.0, p=1.0)
    aug4 = st_transforms.RandomAutocontrast(p=1.0)
    aug5 = st_transforms.GaussianBlur(kernel_size=3, sigma=(0.1, 0.5))

    aug_img1 = aug1(img).unsqueeze(0)
    aug_img2 = aug2(img).unsqueeze(0)
    aug_img3 = aug2(img).unsqueeze(0)
    aug_img4 = aug4(img).unsqueeze(0)
    aug_img5 = aug5(img).unsqueeze(0)
    
    aug_data = torch.cat([img.unsqueeze(0), aug_img1, aug_img2, aug_img3, aug_img4, aug_img5], dim=0)
    return aug_data

def enhanced_uncert_voting(aug_output, confidence_threshold=0.7):
    """Enhanced uncertainty voting with better confidence estimation"""
    aug_all_prob = []
    aug_all_ent = []
    
    for i in range(1, len(aug_output)):
        prob = torch.sigmoid(aug_output[i])
        aug_all_prob.append(prob)
        entropy = -(prob * torch.log(prob + 1e-8) + (1-prob) * torch.log(1-prob + 1e-8))
        aug_all_ent.append(entropy)
    
    no_aug_prob = torch.sigmoid(aug_output[0])
    no_aug_entropy = -(no_aug_prob * torch.log(no_aug_prob + 1e-8) + 
                      (1-no_aug_prob) * torch.log(1-no_aug_prob + 1e-8))
    
    # Ensemble prediction
    if aug_all_prob:
        aug_prob_mean = sum(aug_all_prob) / len(aug_all_prob)
        aug_entropy_mean = sum(aug_all_ent) / len(aug_all_ent)
        
        # Combined confidence based on entropy
        total_entropy = (no_aug_entropy + aug_entropy_mean) / 2
        confidence_map = 1.0 - (total_entropy / np.log(2))  # Normalize
        
        # High confidence mask
        high_conf_mask = confidence_map > confidence_threshold
        
        # Dynamic thresholding
        ensemble_prob = (no_aug_prob + aug_prob_mean) / 2
        adaptive_thresh = ensemble_prob.mean() + 0.1 * ensemble_prob.std()
        adaptive_thresh = torch.clamp(adaptive_thresh, 0.3, 0.7)
        
        pseudo_label = (ensemble_prob > adaptive_thresh).float()
        pseudo_label = pseudo_label * high_conf_mask.float()
        
        return pseudo_label.unsqueeze(0), confidence_map.unsqueeze(0)
    else:
        # Fallback to original
        pseudo_label = dynamic_threshold_label(no_aug_prob)
        confidence = 1.0 - (no_aug_entropy / np.log(2))
        return pseudo_label.unsqueeze(0), confidence.unsqueeze(0)

def dynamic_threshold_label(prob, alpha=0.3, min_thresh=0.3, max_thresh=0.8):
    """Enhanced dynamic threshold with better adaptation"""
    dims = tuple(range(1, len(prob.shape)))
    mean_prob = prob.mean(dim=dims, keepdim=True)
    std_prob = prob.std(dim=dims, keepdim=True)
    
    thresh = mean_prob + alpha * std_prob
    thresh = torch.clamp(thresh, min=min_thresh, max=max_thresh)
    
    pseudo_label = (prob >= thresh).float()
    return pseudo_label

@torch.jit.script
def enhanced_sigmoid_entropy_loss(x: torch.Tensor, weight: float = 1.0) -> torch.Tensor:
    """Enhanced entropy loss with weighting"""
    entropy = -(x*torch.log(x + 1e-8) + (1-x)*torch.log(1-x + 1e-8))
    return weight * entropy.mean()

def enhanced_consistency_loss(msrc_feat, tgt_feat, weights=None):
    """Enhanced consistency loss with optional weighting"""
    if weights is None:
        weights = [1.0] * min(len(msrc_feat), len(tgt_feat))
    
    total_loss = 0
    loss_fn = nn.MSELoss()
    
    for i, weight in enumerate(weights):
        if i < len(msrc_feat) and i < len(tgt_feat):
            total_loss += weight * loss_fn(tgt_feat[i], msrc_feat[i])
    
    return total_loss / len(weights)

def enhanced_sfuda_target(config, train_loader, pseudo_model, msrc_model, criterion, optimizer, 
                         multiscale=None, mc_dropout=None, progressive=None, epoch=0):
    """Enhanced SFUDA target adaptation with academic improvements"""
    avg_meters = {'loss': AverageMeter(), 'iou': AverageMeter()}
    pseudo_model.eval()
    msrc_model.train()
    pbar = tqdm(total=len(train_loader))

    # Progressive threshold
    confidence_thresh = 0.7
    if progressive:
        confidence_thresh = progressive.get_confidence_threshold(epoch)

    for input, target, path in train_loader:
        # Enhanced augmentation
        aug_input = enhanced_build_pseduo_augmentation(input.squeeze(0))
        
        with torch.no_grad():
            if multiscale:
                # Multi-scale pseudo-labeling
                ps_output, conf_map = multiscale.generate_multiscale_pseudo_labels(
                    pseudo_model, input.to(device)
                )
                ps_output = ps_output.unsqueeze(0)
                conf_map = conf_map.unsqueeze(0)
            elif mc_dropout:
                # MC-Dropout uncertainty
                mean_pred, uncertainty = mc_dropout.estimate_uncertainty(
                    pseudo_model, input.to(device)
                )
                ps_output = mean_pred
                conf_map = 1.0 / (1.0 + uncertainty)  # Inverse uncertainty as confidence
            else:
                # Enhanced voting
                aug_output = pseudo_model(aug_input.to(device))
                ps_output, conf_map = enhanced_uncert_voting(aug_output.detach(), confidence_thresh)

        optimizer.zero_grad()
        output = msrc_model(aug_input.to(device))
        
        # Enhanced losses
        if multiscale or mc_dropout:
            # Confidence-weighted loss
            seg_loss = confidence_weighted_loss(
                torch.sigmoid(output), 
                ps_output.repeat(aug_input.size(0), 1, 1, 1).to(device),
                conf_map.repeat(aug_input.size(0), 1, 1, 1).to(device)
            )
        else:
            seg_loss = criterion(output.to(device), ps_output.repeat(aug_input.size(0), 1, 1, 1).to(device))
        
        # Enhanced entropy loss with adaptive weighting
        ent_weight = 0.1 if epoch < config.get('stage1', 50) // 2 else 0.05
        ent_loss = enhanced_sigmoid_entropy_loss(torch.sigmoid(output), ent_weight)
        
        loss = seg_loss + ent_loss

        loss.backward()
        optimizer.step()

        iou, dice = iou_score(output, target.to(device))
        avg_meters['loss'].update(loss.item(), input.size(0))
        avg_meters['iou'].update(iou, input.size(0))

        postfix = OrderedDict([
            ('loss', avg_meters['loss'].avg),
            ('iou', avg_meters['iou'].avg),
            ('conf_thresh', confidence_thresh if progressive else 0.7),
        ])
        pbar.set_postfix(postfix)
        pbar.update(1)
    
    pbar.close()
    return OrderedDict([('loss', avg_meters['loss'].avg),
                        ('iou', avg_meters['iou'].avg)])

def enhanced_sfuda_task(train_loader, msrc_model, tgt_model, criterion, optimizer, 
                       feature_align=False, mc_dropout=None, epoch=0):
    """Enhanced SFUDA task adaptation with feature alignment"""
    avg_meters = {'loss': AverageMeter(), 'iou': AverageMeter()}
    msrc_model.eval()
    tgt_model.train()
    pbar = tqdm(total=len(train_loader))

    for input, target, _ in train_loader:
        w_input = input.to(device)
        target = target.to(device)
        
        image_strong_aug = build_strong_augmentation(input.squeeze(0))
        s_input = image_strong_aug.unsqueeze(0).to(device)

        with torch.no_grad():
            if hasattr(msrc_model, 'forward') and 'mode' in msrc_model.forward.__code__.co_varnames:
                w_output, msrc_feat = msrc_model(w_input, mode='const')
            else:
                w_output = msrc_model(w_input)
                msrc_feat = []
            
            if mc_dropout:
                ps_output, uncertainty = mc_dropout.estimate_uncertainty(msrc_model, w_input)
                confidence = 1.0 / (1.0 + uncertainty)
            else:
                ps_output = torch.sigmoid(w_output).detach()
                # Simple confidence based on distance from 0.5
                confidence = 1.0 - 2.0 * torch.abs(ps_output - 0.5)

        optimizer.zero_grad()

        if hasattr(tgt_model, 'forward') and 'mode' in tgt_model.forward.__code__.co_varnames:
            output, tgt_feat = tgt_model(s_input, mode='const')
        else:
            output = tgt_model(s_input)
            tgt_feat = []

        # Enhanced segmentation loss
        if mc_dropout:
            seg_loss = confidence_weighted_loss(torch.sigmoid(output), ps_output, confidence)
        else:
            # Confidence masking
            conf_mask = confidence > 0.6
            masked_pred = torch.sigmoid(output) * conf_mask
            masked_target = ps_output * conf_mask
            seg_loss = F.mse_loss(masked_pred, masked_target)

        total_loss = seg_loss

        # Feature alignment loss
        if feature_align and msrc_feat and tgt_feat:
            align_loss = feature_alignment_loss(msrc_feat, tgt_feat)
            total_loss += 0.1 * align_loss

        # Enhanced consistency loss
        if msrc_feat and tgt_feat:
            const_loss = enhanced_consistency_loss(msrc_feat, tgt_feat)
            total_loss += 0.1 * const_loss

        total_loss.backward()
        optimizer.step()

        iou, dice = iou_score(output, target)
        avg_meters['loss'].update(total_loss.item(), input.size(0))
        avg_meters['iou'].update(iou, input.size(0))

        postfix = OrderedDict([
            ('loss', avg_meters['loss'].avg),
            ('iou', avg_meters['iou'].avg),
        ])
        pbar.set_postfix(postfix)
        pbar.update(1)

        # Enhanced teacher model update with adaptive rate
        update_rate = 0.999 if avg_meters['iou'].avg > 0.7 else 0.99
        new_msrc_dict = update_teacher_model(tgt_model, msrc_model, keep_rate=update_rate)
        msrc_model.load_state_dict(new_msrc_dict)
        
    pbar.close()
    return OrderedDict([('loss', avg_meters['loss'].avg),
                        ('iou', avg_meters['iou'].avg)])

# ============= KEEP EXISTING FUNCTIONS =============

def build_strong_augmentation(img):
    """Enhanced strong augmentation"""
    augmentation = []
    augmentation.append(st_transforms.RandomApply([st_transforms.ColorJitter(0.4, 0.4, 0.4, 0.1)], p=0.8))
    augmentation.append(st_transforms.RandomGrayscale(p=0.2))
    augmentation.append(st_transforms.RandomApply([st_transforms.GaussianBlur(3, sigma=(0.1, 2.0))], p=0.5))
    strong_aug = st_transforms.Compose(augmentation)
    s_input = strong_aug(img)
    return s_input

@torch.no_grad()
def update_teacher_model(model_student, model_teacher, keep_rate=0.996):
    student_model_dict = model_student.state_dict()
    new_teacher_dict = OrderedDict()
    for key, value in model_teacher.state_dict().items():
        if key in student_model_dict.keys():
            new_teacher_dict[key] = (
                student_model_dict[key] * (1 - keep_rate) + value * keep_rate
            )
        else:
            raise Exception("{} is not found in student model".format(key))
    return new_teacher_dict

def validate(val_loader, model, criterion):
    avg_meters = {'loss': AverageMeter(), 'iou': AverageMeter(), 'dice': AverageMeter()}
    model.eval()
    with torch.no_grad():
        pbar = tqdm(total=len(val_loader))
        for input, target, meta in val_loader:
            input = input.to(device)
            target = target.to(device)
            output = model(input)
            loss = criterion(output, target)
            iou, dice = iou_score(output, target)
            avg_meters['loss'].update(loss.item(), input.size(0))
            avg_meters['iou'].update(iou, input.size(0))
            avg_meters['dice'].update(dice, input.size(0))
            postfix = OrderedDict([
                ('loss', avg_meters['loss'].avg),
                ('iou', avg_meters['iou'].avg),
                ('dice', avg_meters['dice'].avg)
            ])
            pbar.set_postfix(postfix)
            pbar.update(1)
        pbar.close()
    return OrderedDict([('loss', avg_meters['loss'].avg),
                        ('iou', avg_meters['iou'].avg),
                        ('dice', avg_meters['dice'].avg)])

def main():
    args = parse_args()

    config_file = "config_" + args["target"]
    with open('models/%s/%s.yml' % (args["source"], config_file), 'r') as f:
        config = yaml.load(f, Loader=yaml.FullLoader)

    # Initialize academic improvements
    multiscale = MultiScaleConsistency() if args["multiscale"] else None
    mc_dropout = MCDropoutUncertainty() if args["mc_dropout"] else None
    progressive = ProgressiveDomainAdaptation(
        total_epochs=config.get('stage1', 50)
    ) if args["progressive"] else None

    print("Academic improvements enabled:")
    # Data loading (unchanged)
    train_img_ids = glob(os.path.join('inputs', args["target"], 'train','images', '*' + config['img_ext']))
    train_img_ids = [os.path.splitext(os.path.basename(p))[0] for p in train_img_ids]

    val_img_ids = glob(os.path.join('inputs', args["target"], 'test','images', '*' + config['img_ext']))
    val_img_ids = [os.path.splitext(os.path.basename(p))[0] for p in val_img_ids]

    train_transform = Compose([
        RandomRotate90(),
        transforms.Flip(),
        Resize(config['input_h'], config['input_w']),
        transforms.Normalize(),
    ])

    train_dataset = Dataset(
        img_ids=train_img_ids,
        img_dir=os.path.join('inputs', args["target"], 'train','images'),
        mask_dir=os.path.join('inputs', args["target"], 'train','masks'),
        img_ext=config['img_ext'],
        mask_ext=config['mask_ext'],
        num_classes=config['num_classes'],
        transform=train_transform)

    train_loader = torch.utils.data.DataLoader(
        train_dataset, batch_size=1, shuffle=True,
        num_workers=config['num_workers'], drop_last=True)

    val_transform = Compose([
        Resize(config['input_h'], config['input_w']),
        transforms.Normalize(),
    ])

    val_dataset = Dataset(
        img_ids=val_img_ids,
        img_dir=os.path.join('inputs', args["target"],'test', 'images'),
        mask_dir=os.path.join('inputs', args["target"],'test', 'masks'),
        img_ext=config['img_ext'],
        mask_ext=config['mask_ext'],
        num_classes=config['num_classes'],
        transform=val_transform)

    val_loader = torch.utils.data.DataLoader(
        val_dataset, batch_size=1, shuffle=False,
        num_workers=config['num_workers'], drop_last=False)

    # Model setup (unchanged)
    print("Creating model %s...!!!" % config['arch'])
    print("Loading source trained model...!!!")
    
    msrc_model = archs.__dict__[config['arch']](config['num_classes'],
                                               config['input_channels'],
                                               config['deep_supervision'])
    msrc_model.load_state_dict(torch.load('models/%s/model.pth'%config['name'], map_location="mps"))
    msrc_model.to(device)
    msrc_model.train()
    print("Successfully loaded source trained model...!!!")

    tgt_model = archs.__dict__[config['arch']](config['num_classes'],
                                              config['input_channels'],
                                              config['deep_supervision'])
    tgt_model.to(device)
    tgt_model.train()

    src_params = filter(lambda p: p.requires_grad, msrc_model.parameters())
    src_optimizer = optim.Adam(src_params, lr=config['lr'], weight_decay=config['weight_decay'])

    tgt_params = filter(lambda p: p.requires_grad, tgt_model.parameters())
    tgt_optimizer = optim.Adam(tgt_params, lr=config['lr'], weight_decay=config['weight_decay'])

    for c in range(config['num_classes']):
        os.makedirs(os.path.join('outputs', config['name'], str(c)), exist_ok=True)
    
    pseudo_model = archs.__dict__[config['arch']](config['num_classes'],
                                                 config['input_channels'],
                                                 config['deep_supervision'])
    pretrained_dict = msrc_model.state_dict()
    pseudo_model.load_state_dict(pretrained_dict)
    pseudo_model.to(device)
    pseudo_model.eval()

    criterion = losses.__dict__[config['loss']]().to(device)
    
    print("\nTarget specific adaptation with enhancements...!!!")
    best_iou = 0.0
    for epoch in range(config['stage1']):
        train_log = enhanced_sfuda_target(
            config, train_loader, pseudo_model, msrc_model, criterion, src_optimizer,
            multiscale=multiscale, mc_dropout=mc_dropout, progressive=progressive, epoch=epoch
        )
        
        print('Epoch %d - train_loss %.4f - train_iou %.4f' % 
              (epoch, train_log['loss'], train_log['iou']))
        
        # Validation every 10 epochs
        if epoch % 10 == 0 or epoch == config['stage1'] - 1:
            val_log = validate(val_loader, msrc_model, criterion)
            print('Validation - dice: %.4f' % val_log['dice'])
            
            if val_log['dice'] > best_iou:
                best_iou = val_log['dice']
                torch.save(msrc_model.state_dict(), 
                          f'models/{config["name"]}/best_stage1_model.pth')

    msrc_model.eval()
    pretrained_dict = msrc_model.state_dict()
    tgt_model.load_state_dict(pretrained_dict)
    tgt_model.to(device)
    tgt_model.train()

    print("\nTask specific adaptation with enhancements...!!!")
    best_dice = 0.0
    for epoch in range(config['stage2']):
        train_log = enhanced_sfuda_task(
            train_loader, msrc_model, tgt_model, criterion, tgt_optimizer,
            feature_align=args.feature_align, mc_dropout=mc_dropout, epoch=epoch
        )
        
        print('Epoch %d - train_loss %.4f - train_iou %.4f' % 
              (epoch, train_log['loss'], train_log['iou']))
        
        # Validation every 10 epochs
        if epoch % 10 == 0 or epoch == config['stage2'] - 1:
            val_log = validate(val_loader, tgt_model, criterion)
            print('Validation - dice: %.4f' % val_log['dice'])
            
            if val_log['dice'] > best_dice:
                best_dice = val_log['dice']
                torch.save(tgt_model.state_dict(), 
                          f'models/{config["name"]}/best_final_model.pth')
    
    print("\nPerforming final adapted target model evaluation...!!!")
    val_log = validate(val_loader, tgt_model, criterion)
    print('Final adapted target model dice: %.4f' % val_log['dice'])
    print('Best dice achieved: %.4f' % best_dice)

In [24]:
main()

Academic improvements enabled:
Creating model UNet...!!!
Loading source trained model...!!!
Successfully loaded source trained model...!!!

Target specific adaptation with enhancements...!!!


 65%|██████▌   | 13/20 [00:09<00:04,  1.71it/s, loss=0.000139, iou=0.0194, conf_thresh=0.7]

KeyboardInterrupt: 