In [1]:
import argparse
import numpy as np
import torch
import torch.nn as nn
import functools
from torchvision import transforms
from utils.utils import get_strategy
from dataloaders.dataset_fetalhead import BaseDataSets_HC18
import albumentations as A
from torch.nn import functional as F
from torch.utils.data import DataLoader, ConcatDataset, Subset
from networks.vision_transformer import SwinUnet as ViM_seg
from networks.swin_transformer_unet_skip_expand_decoder_sys import SwinTransformerSys
from config import get_config
from networks.unet import UNet as Unet2D
from networks.archs import UNext
from utils.losses import softmax_dice_loss, info_nce_loss

  check_for_updates()


In [2]:
def ConstraLoss_AvgProj(inputs, targets, ndf=64):
    m=nn.AdaptiveAvgPool2d(ndf)
    input_pro = m(inputs)
    input_pro = input_pro.view(inputs.size(0),inputs.size(1),-1) #N*C
    targets_pro = m(targets)
    targets_pro = targets_pro.view(targets.size(0),targets.size(1),-1)#N*C
    input_normal = nn.functional.normalize(input_pro,p=2,dim=1)
    targets_normal = nn.functional.normalize(targets_pro,p=2,dim=1)
    res = (input_normal - targets_normal)
    print(res.shape)
    res = res * res
    print(res.shape)
    loss = torch.mean(res)
    return loss

In [3]:
x1 = torch.rand(4,2,448,448)
x2 = torch.rand(4,2,448,448)
ConstraLoss_AvgProj(x1,x2)

torch.Size([4, 2, 4096])
torch.Size([4, 2, 4096])


tensor(0.0034)

In [2]:
model = UNext(in_chns=3, out_channels=2, img_size=448, num_classes=2)

In [3]:
model

UNext(
  (encoder1): Conv2d(3, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (encoder2): Conv2d(16, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (encoder3): Conv2d(32, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (ebn1): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (ebn2): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (ebn3): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (norm3): LayerNorm((160,), eps=1e-05, elementwise_affine=True)
  (norm4): LayerNorm((256,), eps=1e-05, elementwise_affine=True)
  (dnorm3): LayerNorm((160,), eps=1e-05, elementwise_affine=True)
  (dnorm4): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
  (block1): ModuleList(
    (0): shiftedBlock(
      (drop_path): Identity()
      (norm2): LayerNorm((160,), eps=1e-05, elementwise_affine=True)
      (mlp): shiftmlp(
        (fc1): Linear(in_features=160, out_features

In [None]:
input = torch.rand(1,3,448,448)
testout, emb_out = model(input)

In [36]:
for i in range(1):
    for j in range(5-1):
        print(f'i: {i}, j: {j}, j+args.labeled_bs: {j+1}')

i: 0, j: 0, j+args.labeled_bs: 1
i: 0, j: 1, j+args.labeled_bs: 2
i: 0, j: 2, j+args.labeled_bs: 3
i: 0, j: 3, j+args.labeled_bs: 4


In [39]:
feats1 = torch.randn(5, 50176)
feats2 = torch.randn(5, 50176)

In [56]:
cos_sim = F.cosine_similarity(feats1[:,None,:], feats1[None,:,:], dim=-1)
print(cos_sim)
# Mask out cosine similarity to itself
self_mask = torch.eye(cos_sim.shape[0], dtype=torch.bool, device=cos_sim.device)
print(self_mask.shape)
cos_sim.masked_fill_(self_mask, -9e15)
# Find positive example -> batch_size//2 away from the original example
pos_mask = self_mask.roll(shifts=cos_sim.shape[0]//2-1, dims=0)
# pos_mask = self_mask.roll(shifts=cos_sim.shape[0], dims=0)
print(cos_sim.shape[0]//2-1)
print(pos_mask)
# InfoNCE loss
cos_sim = cos_sim / 0.07
nll = -cos_sim[pos_mask] + torch.logsumexp(cos_sim, dim=-1)
nll = nll.mean()
print(nll)

tensor([[ 1.0000e+00, -2.2207e-03,  3.6291e-03,  9.1015e-03, -1.0551e-03],
        [-2.2207e-03,  1.0000e+00, -3.5107e-03,  4.9809e-04, -5.8683e-04],
        [ 3.6291e-03, -3.5107e-03,  1.0000e+00, -9.2778e-03,  2.5771e-03],
        [ 9.1015e-03,  4.9809e-04, -9.2778e-03,  1.0000e+00,  5.3993e-04],
        [-1.0551e-03, -5.8683e-04,  2.5771e-03,  5.3993e-04,  1.0000e+00]])
torch.Size([5, 5])
1
tensor([[False, False, False, False,  True],
        [ True, False, False, False, False],
        [False,  True, False, False, False],
        [False, False,  True, False, False],
        [False, False, False,  True, False]])
tensor(1.4321)


In [2]:
def info_nce_loss_2(feats1,feats2):
    # Calculate cosine similarity
    # feats1_norm = F.normalize(feats1, dim=1)
    # feats2_norm = F.normalize(feats2, dim=1)
    # cos_sim = torch.matmul(feats1_norm, feats2_norm.T)

    cos_sim = F.cosine_similarity(feats1[:,None,:], feats2[None,:,:], dim=-1)
    cos_sim = cos_sim / 0.07
    pos = torch.diag(cos_sim, 0)
    # print(f'{feats1.shape}, and {feats2.shape}')
    # print(f'cos_sim {cos_sim.shape}')
    # Mask out cosine similarity to itself
    self_mask = torch.eye(cos_sim.shape[0], dtype=torch.bool, device=cos_sim.device)
    print(f'self_mask {self_mask.shape}')
    cos_sim.masked_fill_(self_mask, -9e15)
    # Find positive example -> batch_size//2 away from the original example
    # pos_mask = self_mask.roll(shifts=cos_sim.shape[0]//2, dims=0)
    # InfoNCE loss
    # nll = -cos_sim[pos_mask] + torch.logsumexp(cos_sim, dim=-1)
    nll = -pos + torch.logsumexp(cos_sim, dim=-1)
    nll = nll.mean()
    return nll

In [3]:
class projectors(nn.Module):
    def __init__(self, input_nc=768, ndf=256, norm_layer=nn.BatchNorm2d):
        super(projectors, self).__init__()

        self.pool = nn.MaxPool2d(2, 2)
        self.conv_1 = nn.Linear(input_nc, ndf)
        # self.conv_2 = nn.Linear(ndf, ndf*2)
    def forward(self, input):
        x_out = self.conv_1(input)
        # x_out = self.pool(x_out)
        # x_out = self.conv_2(x_0)
        # x_out = self.pool(x_out)
        return x_out 

In [154]:
input_x = torch.randn(5, 64, 768)
input_y = torch.randn(5, 256, 8, 8)
# emb_x = torch.flatten(input_x, start_dim=1)
# emb_y = torch.flatten(input_y, start_dim=1)
emb_x = input_x
emb_y = input_y.view(5, 256, -1).permute(0, 2, 1)
print(emb_y.shape)

torch.Size([5, 64, 256])


In [155]:
proj_1 = projectors(input_nc=768, ndf=256)
# proj_2 = projectors(input_nc=16, ndf=64)
proj_x = proj_1(emb_x)
# proj_y = proj_2(emb_y)
print(proj_x.shape)
print(emb_y.shape)

torch.Size([5, 64, 256])
torch.Size([5, 64, 256])


In [4]:
a = torch.randn(5, 64)
b = torch.randn(5, 64)
c = torch.cat([a, b], dim=0)
d = torch.cat([b, a], dim=0)
print(c.shape)
test = F.cosine_similarity(c, d, dim=-1)

torch.Size([10, 64])


In [157]:
cos_sim_ab = F.cosine_similarity(a[:,None,:], b[None,:,:], dim=-1)
cos_sim_ab

tensor([[-0.1630,  0.0817,  0.1735, -0.1563,  0.0131],
        [ 0.0332,  0.1688, -0.1064, -0.1731,  0.0920],
        [ 0.0724,  0.0579,  0.0768,  0.0692,  0.1843],
        [ 0.0178, -0.0041,  0.3110, -0.0190, -0.0813],
        [-0.1418,  0.0030, -0.2234,  0.0946, -0.0430]])

In [5]:
info_nce_1 = info_nce_loss(a, b)
info_nce_1

tensor(1.6597)

In [6]:
info_nce_2 = info_nce_loss_2(a, b)
info_nce_2

self_mask torch.Size([5, 5])


tensor(0.9896)

In [167]:
a_norm = F.normalize(a, dim=1)            # Shape: (5, 64)
b_norm = F.normalize(b, dim=1)       # Shape: (64,)

cos_sim = torch.matmul(a_norm, b_norm.T)
cos_sim = cos_sim / 0.07
pos = torch.diagonal(cos_sim.clone(), 0)
cos_sim

tensor([[-2.3287,  1.1678,  2.4785, -2.2325,  0.1865],
        [ 0.4742,  2.4121, -1.5204, -2.4724,  1.3137],
        [ 1.0346,  0.8274,  1.0968,  0.9891,  2.6327],
        [ 0.2544, -0.0590,  4.4435, -0.2712, -1.1612],
        [-2.0257,  0.0431, -3.1919,  1.3515, -0.6142]])

In [168]:
self_mask = torch.eye(cos_sim.shape[0], dtype=torch.bool, device=cos_sim.device)
cos_sim.masked_fill_(self_mask, -9e15)
cos_sim

tensor([[-9.0000e+15,  1.1678e+00,  2.4785e+00, -2.2325e+00,  1.8655e-01],
        [ 4.7425e-01, -9.0000e+15, -1.5204e+00, -2.4724e+00,  1.3137e+00],
        [ 1.0346e+00,  8.2736e-01, -9.0000e+15,  9.8906e-01,  2.6327e+00],
        [ 2.5441e-01, -5.9048e-02,  4.4435e+00, -9.0000e+15, -1.1612e+00],
        [-2.0257e+00,  4.3135e-02, -3.1919e+00,  1.3515e+00, -9.0000e+15]])

### Test SAM

In [2]:
# from utils.prompt import get_bounding_box
from transformers import SamProcessor
from transformers import SamModel
from utils.losses import softmax_kl_loss

In [3]:
device = torch.device("cuda:1" if torch.cuda.is_available() else "cpu")

In [4]:
sam_processor = SamProcessor.from_pretrained("facebook/sam-vit-base")

In [5]:
sam_processor = SamProcessor.from_pretrained("facebook/sam-vit-base")
pretr_sam_model = SamModel.from_pretrained("facebook/sam-vit-base")
sam_checkpoint = '/mnt/storage/fangyijie/ft_sam/checkpoints_head/onlyreal/checkpoint_FTSAM_epoch_19_BS_1_TRAINSIZE_25_Real_20250417-230749'
if sam_checkpoint:
    model_state = torch.load(sam_checkpoint, map_location=torch.device("cuda:1"))
    pretr_sam_model.load_state_dict(model_state)
    pretr_sam_model.to(device)
    print(f"Load SAM model from {sam_checkpoint}")

# make sure to freeze SAM
for name, param in pretr_sam_model.named_parameters():
    param.requires_grad_(False)

Load SAM model from /mnt/storage/fangyijie/ft_sam/checkpoints_head/onlyreal/checkpoint_FTSAM_epoch_19_BS_1_TRAINSIZE_25_Real_20250417-230749


In [10]:
a_in = torch.randn(5, 3, 448, 448).to(device)
test = torch.argmax(a_in, dim=0, keepdim=True)
test.shape

torch.Size([1, 3, 448, 448])

In [7]:
def get_bounding_box(image=None, nobox=True):

    if nobox:
        # bbox = [0, 0, image.size[0], image.size[1]]
        bbox = [np.random.randint(0, 20),
                np.random.randint(0, 20),
                image.shape[0] - np.random.randint(0, 20),
                image.shape[1] - np.random.randint(0, 20)]

    return bbox

In [93]:
outputs = []
pretr_sam_model.eval()
for idx in range(1):
    image = a_in[idx].unsqueeze(0)
    image = torch.argmax(image, dim=0)
    image = image.permute(1, 2, 0)
    print(image.shape)
    print(image.shape[0])
    input_boxes = get_bounding_box(image, nobox = True)
    inputs = sam_processor(image.cpu().numpy(), input_boxes=[[input_boxes]], return_tensors="pt", do_rescale=False).to(torch.float32).to(device)
    # forward pass
    # note that the authors use `multimask_output=False` when performing inference
    output_logits = pretr_sam_model(**inputs, multimask_output=False).pred_masks.squeeze(0)
    print(output_logits.shape)
    outputs.append(output_logits)

torch.Size([448, 448, 3])
448
torch.Size([1, 1, 256, 256])


In [79]:
a_input = a_in[1, 0, :, :]
a_input = a_input[None,]

In [75]:
upsampled_outputs = F.interpolate(outputs.unsqueeze(1), size=(448, 448), mode='bilinear', align_corners=False)
upsampled_outputs.shape

torch.Size([1, 1, 448, 448])

In [86]:
kl_losses = []
for i in range(5):
    kl_div = softmax_kl_loss(a_input, upsampled_outputs.squeeze(1), sigmoid=True)
    # kl_div
    kl_losses.append(kl_div.item())

In [88]:
np.mean(kl_losses)

0.6986963748931885

In [69]:
# apply sigmoid
medsam_seg_prob = torch.sigmoid(outputs)
medsam_seg_prob.shape

torch.Size([1, 256, 256])

In [59]:
input_log_softmax = torch.log(torch.sigmoid(outputs))
target_softmax = torch.sigmoid(b)

In [None]:
kl_div = F.kl_div(input_log_softmax, target_softmax, reduction='mean')
kl_div

tensor(0.1014, device='cuda:1')

In [23]:
x = torch.randn(1, 448, 448)
test = []
for i in range(5):
    test.append(x)

torch.stack(test).size()
# torch.stack((x, y), dim=1).size()

torch.Size([5, 1, 448, 448])

In [10]:
a1 = torch.randn(5, 448, 448)
a2 = torch.randn(5, 448, 448)
kl_div = softmax_kl_loss(a1, a2, sigmoid=True)
kl_div

tensor(0.1037)

In [64]:
m = nn.Upsample(scale_factor=1.75, mode='nearest')
outputs_448 = m(outputs)
outputs_448.shape

torch.Size([1, 256, 448])

### Test Networks

In [5]:
swin_unet = SwinTransformerSys(img_size=448,
                patch_size=4,
                in_chans=3,
                num_classes=2,
                embed_dim=96,
                depths=[ 2, 2, 2, 2 ],
                num_heads=[ 3, 6, 12, 24 ],
                window_size=7,
                mlp_ratio=4.0,
                qkv_bias=True,
                qk_scale=False,
                drop_rate=0.0,
                drop_path_rate=0.2,
                ape=False,
                patch_norm=True,
                use_checkpoint=False)

SwinTransformerSys expand initial----depths:[2, 2, 2, 2];depths_decoder:[1, 2, 2, 2];drop_path_rate:0.2;num_classes:2
---final upsample expand_first---


In [3]:
input_x = torch.randn(1, 3, 448, 448)

In [19]:
_, emb = swin_unet.forward_features(input_x)
bottleneck = emb[-1]
print(bottleneck.shape)

torch.Size([1, 64, 768])


In [12]:
model1 = Unet2D(in_chns=3, class_num=2)

In [13]:
y_unet = model1(input_x)
print(y_unet.shape)

torch.Size([1, 2, 256, 256])


In [5]:
UNext_model = UNext(
    in_channels=3,
    out_channels=2,
    img_size=256,
    num_classes=2
)

In [7]:
y_UNext, emb = UNext_model(input_x)
print(y_UNext.shape)
print(emb.shape)

torch.Size([1, 2, 256, 256])
torch.Size([1, 256, 8, 8])


In [None]:
UNext_model

In [23]:
fc2 = nn.Linear(768, 256)

In [28]:
class NN_Network(nn.Module):
    def __init__(self,in_dim,hid):
        super(NN_Network, self).__init__()
        self.linear1 = nn.Linear(in_dim,hid)

    def forward(self, input_array):
        h = self.linear1(input_array)
        return h

In [29]:
nn_net = NN_Network(768, 256)
total_params = sum(p.numel() for p in nn_net.parameters())
total_params

196864

In [24]:
test = fc2(bottleneck)
print(test.shape)

torch.Size([1, 64, 256])


In [2]:
root_path = "/mnt/storage/fangyijie/HC18"

In [3]:
tr_transforms = A.Compose(
    [
        A.Rotate(limit=20, p=1.0),
        A.RandomBrightnessContrast(brightness_limit=(-0.2, 0.2),contrast_limit=(-0.5, 0.5)),
        A.Blur(blur_limit=(3, 3), p=0.3),
        A.GaussNoise(std_range=(0.05, 0.1), p=0.3),
        # A.HorizontalFlip(p=0.5),
        # A.VerticalFlip(p=0.9),
        # A.Normalize(
        #     mean=[0.0, 0.0, 0.0],
        #     std=[1.0, 1.0, 1.0],
        #     max_pixel_value=255.0,
        # ),
        ],
)

In [4]:
tensor_transforms = transforms.Compose([
    transforms.ToTensor(),
    transforms.Resize((448, 448), transforms.InterpolationMode.NEAREST),
])

In [5]:
db_train_unlabled = BaseDataSets_HC18(base_dir=root_path, split="train", 
                                      islabeled=False, 
                                      transform=None, default_transform=tensor_transforms)

In [6]:
# apply transforms
db_train_unlabled.transform = tr_transforms

In [7]:
db_train_unlabled = BaseDataSets_HC18(base_dir=root_path, split="train", 
                                      islabeled=False, 
                                      isAL=True,
                                      transform=None, default_transform=tensor_transforms)

FileNotFoundError: [Errno 2] No such file or directory: '/mnt/storage/fangyijie/HC18/ssl/train/labeled_data/None/data_al_2_None.list'

In [None]:
for i_data in db_train_unlabled:
    mask = i_data['label']
    print(i_data['idx'])
    print(i_data['name'])
    print(mask.shape)
    print(torch.argmax(mask))
    print(torch.argmin(mask))
    break

In [None]:
strategy = get_strategy("RandomSampling")(db_train_unlabled) # load strategy

In [None]:
query_idxs = strategy.query(5) #([500, 1, 128, 128])

In [None]:
query_idxs

In [None]:
labeled_slice = 5
al_iter = 2
strategy_name = "RandomSampling"
# strategy_name = None

In [None]:
db_train_labled = BaseDataSets_HC18(base_dir=root_path, split="train", num=labeled_slice, islabeled=True, 
                                        transform=None, default_transform=tensor_transforms)
db_train_labled_da = BaseDataSets_HC18(base_dir=root_path, split="train", num=labeled_slice, islabeled=True, 
                                        transform=tr_transforms, default_transform=tensor_transforms)

db_train = ConcatDataset([db_train_labled, db_train_labled_da])

if strategy_name is not None and al_iter > 1:
    db_train_al_labled = BaseDataSets_HC18(base_dir=root_path, split="train", num=labeled_slice, islabeled=True, 
                                            isAL=True, alNum=al_iter, stgrategy=strategy_name,
                                            transform=None, default_transform=tensor_transforms)
    db_train_al_labled_da = BaseDataSets_HC18(base_dir=root_path, split="train", num=labeled_slice, islabeled=True, 
                                                isAL=True, alNum=al_iter, stgrategy=strategy_name,
                                                transform=tr_transforms, default_transform=tensor_transforms)
    
    db_train = ConcatDataset([db_train, db_train_al_labled, db_train_al_labled_da])

    db_train_unlabled = BaseDataSets_HC18(base_dir=root_path, split="train", num=labeled_slice, islabeled=False, 
                                            isAL=True, alNum=al_iter, stgrategy=strategy_name,
                                            transform=None, default_transform=tensor_transforms)
else:
    db_train_unlabled = BaseDataSets_HC18(base_dir=root_path, split="train", islabeled=False, 
                                                transform=None, default_transform=tensor_transforms)

In [None]:
for i_data in db_train_unlabled:
    # print(i_data['idx'])
    # print(i_data['name'])
    # print(i_data['image'].shape)
    # print(i_data['label'].shape)
    test = np.unique(i_data['label'])
    if len(test)>1:
        print(test)
        print(i_data['name'])
    # break

In [None]:
for i_data in db_train:
    # print(i_data['idx'])
    # print(i_data['name'])
    # print(i_data['image'].shape)
    # print(i_data['label'].shape)
    test = np.unique(i_data['label'])
    if len(test)==1:
        print(test)
        print(i_data['name'])
    # break

In [None]:
print(f'total lenght of train data: {len(db_train)}')

In [None]:
print(f'total lenght of train data: {len(db_train_unlabled)}')

In [None]:
for i_data in db_train:
    # print(i_data['idx'])
    print(i_data['name'])
    print(i_data['image'].shape)
    print(i_data['label'].shape)
    # break