In [1]:
import argparse
import builtins
import math
import os
import random
import shutil
import time
import warnings
from tqdm import tqdm
import numpy as np
import faiss

import torch
import torch.nn as nn
import torch.nn.parallel
import torch.backends.cudnn as cudnn
import torch.optim
import torch.utils.data
import torchvision.transforms as transforms
import torchvision.datasets as datasets
import torchvision.models as models
import pcl.loader
import pcl.builder

import torch.nn.functional as F
import wandb

In [2]:
# 初始化老師模型
print("=> creating teacher model")
teacher_model = pcl.builder.MoCo(
    models.__dict__['resnet50'],
    dim=128, r=512, m=0.9, T=0.7, mlp=False)
# 替換分類頭
# teacher_model.fc = nn.Linear(128, 200).cuda()
teacher_model = teacher_model.cuda(0)
print(teacher_model)
print(f"Teacher model device: {next(teacher_model.parameters()).device}")

print("=> creating student model ")
model = pcl.builder.MoCo(
    models.__dict__['resnet50'],
    dim=128, r=512, m=0.9, T=0.7, mlp=False)
model = model.cuda(0)
print(f"student model device: {next(model.parameters()).device}")

=> creating teacher model
MoCo(
  (encoder_q): ResNet(
    (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
    (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (relu): ReLU(inplace=True)
    (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
    (layer1): Sequential(
      (0): Bottleneck(
        (conv1): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (bn3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (relu): ReLU(inplace=True)
        (downsample

In [3]:
criterion = nn.CrossEntropyLoss().cuda(0)
optimizer = torch.optim.SGD(model.parameters(), 0.005,
                            momentum=0.9,
                            weight_decay=1e-4)
                            

In [4]:
path = r"C:\Users\k3866\Documents\Datasets\tiny_imagenet\tiny-imagenet-200"
traindir = os.path.join(path, 'train')
valdir = os.path.join(path, 'val')
print("train dir path :",traindir)
print("val dir path :",valdir)
normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                    std=[0.229, 0.224, 0.225])

train dir path : C:\Users\k3866\Documents\Datasets\tiny_imagenet\tiny-imagenet-200\train
val dir path : C:\Users\k3866\Documents\Datasets\tiny_imagenet\tiny-imagenet-200\val


In [5]:
augmentation = [
    transforms.RandomResizedCrop(64, scale=(0.2, 1.)),
    transforms.RandomApply([
        transforms.ColorJitter(0.4, 0.4, 0.4, 0.1)
    ], p=0.8),
    transforms.RandomGrayscale(p=0.2),
    transforms.RandomApply([pcl.loader.GaussianBlur([.1, 2.])], p=0.5),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    normalize
]

In [6]:
eval_augmentation = transforms.Compose([
    transforms.Resize(64),
    transforms.CenterCrop(64),
    transforms.ToTensor(),
    normalize
])

In [7]:
train_dataset = pcl.loader.ImageFolderInstance(
    traindir,
    pcl.loader.TwoCropsTransform(transforms.Compose(augmentation)))
    
eval_dataset = pcl.loader.ImageFolderInstance(
    traindir,
    eval_augmentation)

In [8]:
train_loader = torch.utils.data.DataLoader(
    train_dataset, batch_size=256, shuffle=True,
    num_workers=16, pin_memory=True, drop_last=True)

eval_loader = torch.utils.data.DataLoader(
    eval_dataset, batch_size=256, shuffle=False,
    num_workers=16, pin_memory=True)

In [9]:
# # 加載預訓練的老師模型權重
# checkpoint_path = "checkpoint.pth.tar"
checkpoint_path = r"C:\Users\k3866\Documents\PretrianedModel\Moco\checkpoint_0099.pth.tar"
checkpoint = torch.load(checkpoint_path, map_location="cpu",weights_only=True)
state_dict = checkpoint["state_dict"]

# 處理權重名稱
new_state_dict = {}
for k in list(state_dict.keys()):
    if k.startswith("encoder_q"):
        new_state_dict[f"encoder_q.{k[len('encoder_q.'):]}"] = state_dict[k]
    elif k.startswith("encoder_k"):
        new_state_dict[f"encoder_k.{k[len('encoder_k.'):]}"] = state_dict[k]

# 加載權重到老師模型
msg = teacher_model.load_state_dict(new_state_dict, strict=False)
print(f"Missing keys: {msg.missing_keys}")
print(f"Unexpected keys: {msg.unexpected_keys}")
teacher_model = teacher_model.cuda(0)
teacher_model.eval()

Missing keys: ['queue', 'queue_ptr']
Unexpected keys: []


MoCo(
  (encoder_q): ResNet(
    (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
    (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (relu): ReLU(inplace=True)
    (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
    (layer1): Sequential(
      (0): Bottleneck(
        (conv1): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (bn3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (relu): ReLU(inplace=True)
        (downsample): Sequential(
          (

In [11]:
def set_centroids(teacher_model, train_loader, num_classes, feature_dim):
    print('Computing Centroid via SSLCon model...')
    centroids = torch.zeros(num_classes, feature_dim).cuda()
    print("centroid :",centroids.shape)
    print(centroids)
    counts = torch.zeros(num_classes).cuda()
    print("counts :",counts.shape)
    print(counts)
    with torch.no_grad():
        progress_bar = tqdm(train_loader, desc="Computing Centroids")
        for idx, (images, labels) in enumerate(progress_bar):
            progress_bar.set_description(f"[{idx + 1}/{len(train_loader)}]")
            image = images[0].cuda()
            labels = labels.cuda()
            features = teacher_model.encoder_q(image)
            for i in range(num_classes):
                mask = labels == i
                # print(f"Mask (labels == {i}): {mask}")  # 打印該類別的 mask，表示哪些樣本屬於該類別
                centroids[i] += features[mask].sum(dim=0)
                # print(f"Updated centroids[{i}]: {centroids[i]}")  # 打印更新後的 centroids[i]
                counts[i] += mask.sum()
                # print(f"Updated counts[{i}]: {counts[i]}")  # 打印更新後的 counts[i]
    centroids /= counts.unsqueeze(1)
    # centroids = F.normalize(centroids, dim=-1)
    return centroids, counts

In [12]:
import torch
with torch.no_grad():
    for images, _ in train_loader:
        features = teacher_model.encoder_k(images[0].cuda())
        feature_dim = features.shape[1]
        print(f"Feature shape: {features.shape}")
        break
num_classes = len(train_loader.dataset.classes)
class_centroids, class_counts = set_centroids(teacher_model, train_loader, num_classes, feature_dim)
print(f"Class centroids computed. Centroids shape: {class_centroids.shape}")
# print(class_centroids)
# class_centroids = get_class_centroids(teacher_model, train_loader,gpu)
class_centroids = class_centroids.cuda(0)

Feature shape: torch.Size([256, 128])
Computing Centroid via SSLCon model...
centroid : torch.Size([200, 128])
tensor([[0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        ...,
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.]], device='cuda:0')
counts : torch.Size([200])
tensor([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,

[390/390]: 100%|██████████| 390/390 [01:49<00:00,  3.55it/s]

Class centroids computed. Centroids shape: torch.Size([200, 128])





### kmeans 計算

In [None]:
def cosine_similarity_matrix(z_q, class_centroids, eps=1e-8):
    # 正規化
    z_q_norm = F.normalize(z_q, p=2, dim=1,eps=eps)
    class_centroids_norm = F.normalize(class_centroids, p=2, dim=1,eps=eps)
    # 矩陣乘法計算餘弦相似度
    return torch.matmul(z_q_norm, class_centroids_norm.t())

def compute_features(eval_loader, model):
    print('Computing features...')
    model.eval()
    features = torch.zeros(len(eval_loader.dataset), 128).cuda()
    for i, (images, index) in enumerate(tqdm(eval_loader)):
        with torch.no_grad():
            images = images.cuda(non_blocking=True)
            feat = model(images, is_eval=True)
            features[index] = feat
    return features.cpu()



In [None]:
# def run_kmeans(x):
#     print('performing kmeans clustering')
#     results = {'im2cluster': [], 'centroids': [], 'density': []}
#     num_cluster = "1000,1500,2000"
#     num_cluster = num_cluster.split(',')
#     for seed, num_cluster in enumerate(num_cluster):
#         d = x.shape[1]
#         k = int(num_cluster)
#         clus = faiss.Clustering(d, k)
#         clus.verbose = True
#         clus.niter = 20
#         clus.nredo = 5
#         clus.seed = seed
#         clus.max_points_per_centroid = 200
#         clus.min_points_per_centroid = 10

#         res = faiss.StandardGpuResources()
#         cfg = faiss.GpuIndexFlatConfig()
#         cfg.useFloat16 = False
#         cfg.device = 0
#         index = faiss.GpuIndexFlatL2(res, d, cfg)

#         clus.train(x, index)

#         D, I = index.search(x, 1)
#         im2cluster = [int(n[0]) for n in I]

#         centroids = faiss.vector_to_array(clus.centroids).reshape(k, d)

#         Dcluster = [[] for c in range(k)]
#         for im, i in enumerate(im2cluster):
#             Dcluster[i].append(D[im][0])

#         density = np.zeros(k)
#         for i, dist in enumerate(Dcluster):
#             if len(dist) > 1:
#                 d = (np.asarray(dist) ** 0.5).mean() / np.log(len(dist) + 10)
#                 density[i] = d

#         dmax = density.max()
#         for i, dist in enumerate(Dcluster):
#             if len(dist) <= 1:
#                 density[i] = dmax

#         density = density.clip(np.percentile(density, 10), np.percentile(density, 90))
#         density = 0.2 * density / density.mean()

#         centroids = torch.Tensor(centroids).cuda()
#         centroids = nn.functional.normalize(centroids, p=2, dim=1)

#         im2cluster = torch.LongTensor(im2cluster).cuda()
#         density = torch.Tensor(density).cuda()

#         results['centroids'].append(centroids)
#         results['density'].append(density)
#         results['im2cluster'].append(im2cluster)

#     return results

In [15]:
def apply_masking(features, cluster_assignments):
    """
    根據 clustering assignments 和遮罩策略對特徵數據應用遮罩。
    - features: 特徵數據，形狀為 (N, D)，N 是樣本數，D 是特徵維度。
    - cluster_assignments: 每個樣本的分群結果。
    - args: 包含遮罩模式和相關參數的配置。

    返回:
    - masked_features: 經過遮罩的特徵數據。
    """
    from collections import defaultdict
    import scipy.spatial as ss
    import numpy as np

    # 初始化變數
    clusters = defaultdict(list)  # 每個 cluster 包含的樣本索引
    max_dis_list = []  # 儲存需要遮罩的樣本索引
    mask_mode = 'mask_farthest'
    dist_threshold= 0.3
    proportion = 0.1
    # 將數據根據 cluster_assignments 分組
    for idx, cluster_id in enumerate(cluster_assignments):
        clusters[cluster_id].append(idx)
    print(f"Number of clusters: {len(clusters)}")

    # 遍歷每個 cluster，根據遮罩策略篩選需要遮罩的樣本
    for cluster_id, indices in clusters.items():
        cluster_features = features[indices]  # 提取該 cluster 的特徵
        centroid = np.mean(cluster_features, axis=0)  # 計算該 cluster 的質心
        print(f"Cluster {cluster_id}: Centroid computed.")
        # print(f"Centroid (first 5 values): {centroid[:5]}")

        # 計算每個樣本與質心的歐氏距離
        distances = [ss.distance.euclidean(centroid, features[idx]) for idx in indices]
        print(f"Cluster {cluster_id}: Computed distances (first 5): {distances[:5]}")

        if mask_mode == 'mask_farthest':
            # 遮罩距離質心最遠的樣本
            max_idx = indices[np.argmax(distances)]
            max_dis_list.append(max_idx)
            print(f"Cluster {cluster_id}: Masking farthest sample (index {max_idx}).")
        elif mask_mode == 'mask_threshold':
            # 遮罩距離超過指定閾值的樣本
            for idx, dist in zip(indices, distances):
                if dist > dist_threshold:
                    max_dis_list.append(idx)
                    print(f"Cluster {cluster_id}: Masking sample (index {idx}) with distance {dist:.3f}.")
        elif mask_mode == 'mask_proportion':
            # 遮罩指定比例的最遠樣本
            num_to_mask = int(len(indices) * proportion)
            sorted_indices = sorted(zip(indices, distances), key=lambda x: x[1], reverse=True)
            max_dis_list.extend([x[0] for x in sorted_indices[:num_to_mask]])
            print(f"Cluster {cluster_id}: Masking {num_to_mask} farthest samples.")
    # 將被遮罩的樣本設為零向量
    masked_features = features.copy()
    print(f"Total masked samples: {len(max_dis_list)}")
    
    for idx in max_dis_list:
        print(f"Masking sample at index {idx}.")
        masked_features[idx] = 0.0
    print("Max dis list len : ",len(max_dis_list))
    # print("Max features:",masked_features)
    print("Shape : ",masked_features.shape)
    return masked_features


In [16]:
def run_kmeans(features):
    """
    執行 KMeans 聚類，並應用遮罩策略。
    - features: 特徵數據，形狀為 (N, D)。
    - args: 包含 KMeans 和遮罩相關配置的參數。

    返回:
    - results: 包含質心、密度和分群結果的字典。
    """
    print('Performing kmeans clustering with masking...')
    results = {'im2cluster': [], 'centroids': [], 'density': []}
    masked_features = features.clone()  # 初始化 masked_features
    num_cluster = "1000,1500,2000"
    num_cluster = num_cluster.split(',')
    mask_mode = 'mask_farthest'
    # dist_threshold= 0.3
    # proportion = 0.1
    for seed, num_cluster in enumerate(num_cluster):
    # for seed, num_cluster in enumerate(args.num_cluster):
        d = masked_features.shape[1]
        k = int(num_cluster)
        
        # 初始化 FAISS 聚類
        clus = faiss.Clustering(d, k)
        clus.verbose = True
        clus.niter = 20
        clus.nredo = 5
        clus.seed = seed
        clus.max_points_per_centroid = 200
        clus.min_points_per_centroid = 10

        res = faiss.StandardGpuResources()
        cfg = faiss.GpuIndexFlatConfig()
        cfg.useFloat16 = False
        cfg.device = 0
        index = faiss.GpuIndexFlatL2(res, d, cfg)

        # 執行聚類
        clus.train(masked_features.cpu().numpy(), index)

        # 搜索最近的聚類中心
        D, I = index.search(masked_features.cpu().numpy(), 1)
        im2cluster = [int(n[0]) for n in I]

        # 計算每個 cluster 的距離
        Dcluster = [[] for c in range(k)]
        for im, i in enumerate(im2cluster):
            Dcluster[i].append(D[im][0])

        density = np.zeros(k)
        for i, dist in enumerate(Dcluster):
            if len(dist) > 1:
                density_value = (np.asarray(dist) ** 0.5).mean() / np.log(len(dist) + 10)
                density[i] = density_value

        dmax = density.max()
        for i, dist in enumerate(Dcluster):
            if len(dist) <= 1:
                density[i] = dmax

        density = density.clip(np.percentile(density, 10), np.percentile(density, 90))
        density = 0.2 * density / density.mean()

        # 添加遮罩檢查之前
        # print(f"Cluster assignments (first 10): {im2cluster[:10]}")  # 確認分群結果
        print(f"Applying masking with mode: {mask_mode}")  # 確認遮罩模式

        # 應用遮罩策略，更新 masked_features
        masked_features = apply_masking(
            masked_features.cpu().numpy(),
            cluster_assignments=im2cluster,
        )
        masked_features = torch.tensor(masked_features).cuda()  # 轉換為 PyTorch 張量
        # print("轉換為 PyTorch 張量的 masked_features",masked_features,"shape : ",masked_features.shape)
        # print(f"Masked features (first row): {masked_features[0].cpu().numpy()}")  # 確認遮罩效果
        
        # 更新聚類結果
        print("Cluster K X D",k,d )
        centroids = faiss.vector_to_array(clus.centroids).reshape(k, d)
        centroids = torch.tensor(centroids).cuda()
        centroids = nn.functional.normalize(centroids, p=2, dim=1)
        
        im2cluster = torch.LongTensor(im2cluster).cuda()
        density = torch.Tensor(density).cuda()

        results['density'].append(density)
        results['centroids'].append(centroids)
        results['im2cluster'].append(im2cluster)

    return results

In [18]:
def adjust_learning_rate(optimizer, epoch):
    lr = 0.005
    # if args.cos:
    lr *= 0.5 * (1. + math.cos(math.pi * epoch / 200))
    # else:
    #     for milestone in args.schedule:
    #         lr *= 0.1 if epoch >= milestone else 1.
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr
        
def accuracy(output, target, topk=(1,)):
    with torch.no_grad():
        # print("output shape:",output.shape)
        if output is None or target is None:
            raise ValueError("Output or target is None")
        # print(f"Output shape: {output.shape}, Target shape: {target.shape}")
        maxk = max(topk)
        batch_size = target.size(0)

        _, pred = output.topk(maxk, 1, True, True)
        pred = pred.t()
        correct = pred.eq(target.view(1, -1).expand_as(pred))

        res = []
        for k in topk:
            correct_k = correct[:k].view(-1).float().sum(0, keepdim=True)
            res.append(correct_k.mul_(100.0 / batch_size))
        return res

### Process 紀錄的class

In [19]:
class AverageMeter(object):
    def __init__(self, name, fmt=':f'):
        self.name = name
        self.fmt = fmt
        self.reset()

    def reset(self):
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count

    def __str__(self):
        fmtstr = '{name} {val' + self.fmt + '} ({avg' + self.fmt + '})'
        return fmtstr.format(**self.__dict__)

class ProgressMeter(object):
    def __init__(self, num_batches, meters, prefix=""):
        self.batch_fmtstr = self._get_batch_fmtstr(num_batches)
        self.meters = meters
        self.prefix = prefix

    def display(self, batch):
        entries = [self.prefix + self.batch_fmtstr.format(batch)]
        entries += [str(meter) for meter in self.meters]
        print('\t'.join(entries))

    def _get_batch_fmtstr(self, num_batches):
        num_digits = len(str(num_batches // 1))
        fmt = '{:' + str(num_digits) + 'd}'
        return '[' + fmt + '/' + fmt.format(num_batches) + ']'

In [20]:
def knowledge_distillation_loss(teacher_probs, student_probs, temperature=1.0):
    """
    計算 KD (Knowledge Distillation) 損失
    - teacher_probs: Teacher 模型的概率分布 (經過 softmax)。
    - student_probs: Student 模型的概率分布 (經過 softmax)。
    - temperature: 蒸餾溫度，默認為 1.0。較高的溫度可以使概率分布更加平滑。
    """
    # 調整 Teacher 和 Student 的概率分布，加入溫度系數
    teacher_probs = F.softmax(teacher_probs / temperature, dim=1)
    student_probs = F.softmax(student_probs / temperature, dim=1)
    
    # 計算 KL 散度損失 (使用 PyTorch 的 kl_div 函數)
    loss = F.kl_div(
        input=torch.log(student_probs),  # Student 的對數概率分布
        target=teacher_probs,           # Teacher 的目標概率分布
        reduction='batchmean'           # 平均計算每個 batch 的損失
    )
    
    # 返回損失值，考慮溫度對梯度的影響
    return loss * (temperature ** 2)


In [None]:
def train(train_loader, model, teacher_model,criterion, optimizer, epoch,cluster_result,class_centroids):
    
# def train(train_loader, model, teacher_model,criterion, optimizer, epoch, args, cluster_result):
    batch_time = AverageMeter('Time', ':6.3f')
    data_time = AverageMeter('Data', ':6.3f')
    total_losses = AverageMeter('TotalLoss', ':.4e')
    info_losses = AverageMeter('InfoNCE_Loss', ':.4e')
    proto_losses = AverageMeter('ProtoLoss', ':.4e')
    centroid_losses = AverageMeter('Centroid Loss', ':.4e')
    kd_losses = AverageMeter('KD Loss', ':.4e')

    acc_inst = AverageMeter('Acc@Inst', ':6.2f')
    acc_proto = AverageMeter('Acc@Proto', ':6.2f')

    progress = ProgressMeter(
        len(train_loader),
        [batch_time, data_time, total_losses, info_losses, proto_losses,centroid_losses,kd_losses, acc_inst, acc_proto],
        prefix="Epoch: [{}]".format(epoch))

    model.train()
    teacher_model.eval()
    end = time.time()

    epoch_iterator = tqdm(train_loader, desc=f"Epoch {epoch}", unit="batch")

    for i, (images, index) in enumerate(epoch_iterator):
        data_time.update(time.time() - end)
        
        images[0] = images[0].cuda(0, non_blocking=True)
        images[1] = images[1].cuda(0, non_blocking=True)

        output, target,output_proto, student_key,target_proto = model(im_q=images[0], im_k=images[1], cluster_result=cluster_result, index=index)
        
        print(f"Output: {output.shape}")
        print(f"target: {target.shape}")
        
        loss_info_value = criterion(output, target)
        info_losses.update(loss_info_value.item(), images[0].size(0))
        total_loss_value = loss_info_value
        
        # Knowledge Distillation 
        # soft targets
        # soft label
        student_probs = student_key
        print(f"student probs shape:{student_probs.shape}")
        with torch.no_grad():
            teacher_embeddings = teacher_model.encoder_q(images[0]).cuda(0)  # Teacher 使用 key encoder
            teacher_probs =teacher_embeddings
            print(f"teacher_probs shape:{teacher_probs.shape}")

        kd_loss = knowledge_distillation_loss(teacher_probs, student_probs, temperature=1)
        print(f"kd_loss: {kd_loss}")
        kd_losses.update(kd_loss.item(),images[0].size(0))
        
        total_loss_value += kd_loss

        # # 切換為評估模式，計算質心對齊損失
        model.eval()
        with torch.no_grad():
            z_q = model.encoder_q(images[0]).cuda(0)  # 明確使用 query 編碼器
            z_k = model.encoder_k(images[1]).cuda(0)  # 明確使用 key 編碼器
            # # 在此插入檢查代碼
            # print(f"Class centroids device: {class_centroids.device}")
            # print(f"z_q device: {z_q.device}")
            # print(f"z_k device: {z_k.device}")

        # # # 確保所有計算都在 GPU
        # sim_q = cosine_similarity(z_q.unsqueeze(1), class_centroids.unsqueeze(0), dim=2).cuda(args.gpu)
        # sim_k = cosine_similarity(z_k.unsqueeze(1), class_centroids.unsqueeze(0), dim=2).cuda(args.gpu)
        sim_q = cosine_similarity_matrix(z_q, class_centroids,eps=1e-8)
        sim_k = cosine_similarity_matrix(z_k, class_centroids,eps=1e-8)
        print(f"Sim_q Shape: {sim_q.shape}")
        # print(f"Sim_q:{sim_q}\n")
        print(f"Sim_k Shape: {sim_k.shape}")
        # print(f"Sim_k:{sim_k}\n")

        # # 再次檢查 sim_q 和 sim_k
        # print(f"sim_q device: {sim_q.device}")
        # print(f"sim_k device: {sim_k.device}")

        # # Normalize centroids on GPU
        # print(f"Before normalizing centroids: {class_centroids.device}")
        # class_centroids = F.normalize(class_centroids, p=2, dim=1).cuda(0)
        # print(f"After normalizing centroids: {class_centroids.shape}")
        
        # print(f"z_q device: {z_q.device}, class_centroids device: {class_centroids.device}")
        loss_centroid_value = 1 - F.cosine_similarity(sim_q, sim_k).mean()
        # print(f"F.cosine_similarity(sim_q, sim_k).mean(): {F.cosine_similarity(sim_q, sim_k).mean()}")
        # print(f"Loss centroid value calculated on device: {loss_centroid_value}")
        model.train()
        
        # # Update Centroid Loss
        total_loss_value += 0.5 * loss_centroid_value
        centroid_losses.update(loss_centroid_value.item(), images[0].size(0))
        if output_proto is not None:
            loss_proto_value = 0
            num_cluster = "1000,1500,2000"
            num_cluster = num_cluster.split(',')
            for proto_out, proto_target in zip(output_proto, target_proto):
                proto_target = proto_target.cuda(0, non_blocking=True)
                proto_loss = criterion(proto_out, proto_target)
                loss_proto_value += proto_loss
                accp = accuracy(proto_out, proto_target)[0]
                acc_proto.update(accp.item(), images[0].size(0))
            
            if len(num_cluster) > 0:
                loss_proto_value /= len(num_cluster)
                total_loss_value += loss_proto_value
                print(f"loss_proto_value: {loss_proto_value}, type: {type(loss_proto_value)}")
                proto_losses.update(loss_proto_value.item(), images[0].size(0))


        total_losses.update(total_loss_value.item(), images[0].size(0))
        acc_result = accuracy(output, target)
        # print(f"Accuracy result: {acc_result}")  # 應該打印非 None 值
        acc = acc_result[0]
        acc_inst.update(acc.item(), images[0].size(0))

        optimizer.zero_grad()
        total_loss_value.backward()
        optimizer.step()

        batch_time.update(time.time() - end)
        end = time.time()

        if i % 100 == 0:
            progress.display(i)
    wandb.log({
        "epoch": epoch,
        "total_loss": total_losses.avg,
        "info_nce_loss": info_losses.avg,
        "proto_nce_loss": proto_losses.avg,
        "centroid_loss": centroid_losses.avg,
        "kd_loss":kd_losses.avg,
        "accuracy_inst": acc_inst.avg,
        "accuracy_proto": acc_proto.avg,
    })

In [22]:
def save_checkpoint(state, is_best, filename='checkpoint.pth.tar'):
    torch.save(state, filename)
    if is_best:
        shutil.copyfile(filename, 'model_best.pth.tar')

In [23]:
for epoch in range(0,200):
    cluster_result = None
    if epoch >= 1:
        # cluster_result
        features = compute_features(eval_loader, model)
        cluster_result = run_kmeans(features)

    adjust_learning_rate(optimizer, epoch)
    train(train_loader, model,teacher_model, criterion, optimizer, epoch, cluster_result,class_centroids)
    # train(train_loader, model,teacher_model, criterion, optimizer, epoch, args, cluster_result)

    if (epoch + 1) % 5 == 0:
        save_checkpoint({
            'epoch': epoch + 1,
            'arch':'resnet',
            'state_dict': model.state_dict(),
            'optimizer': optimizer.state_dict(),
        }, is_best=False, filename='{}/checkpoint_{:04d}.pth.tar'.format('experiment_pcl', epoch))

Epoch 0:   0%|          | 0/390 [00:00<?, ?batch/s]

Output: torch.Size([256, 513])
target: torch.Size([256])
student probs shape:torch.Size([256, 128])
teacher_probs shape:torch.Size([256, 128])
kd_loss: 0.5141429305076599
Sim_q Shape: torch.Size([256, 200])
Sim_k Shape: torch.Size([256, 200])
After normalizing centroids: torch.Size([200, 128])


Epoch 0:   0%|          | 1/390 [00:52<5:37:48, 52.10s/batch]

Epoch: [0][  0/390]	Time 52.106 (52.106)	Data 44.394 (44.394)	TotalLoss 5.8364e+00 (5.8364e+00)	InfoNCE_Loss 5.1733e+00 (5.1733e+00)	ProtoLoss 0.0000e+00 (0.0000e+00)	Centroid Loss 2.9799e-01 (2.9799e-01)	KD Loss 5.1414e-01 (5.1414e-01)	Acc@Inst 100.00 (100.00)	Acc@Proto   0.00 (  0.00)
Output: torch.Size([256, 513])
target: torch.Size([256])
student probs shape:torch.Size([256, 128])
teacher_probs shape:torch.Size([256, 128])
kd_loss: 0.4925138056278229
Sim_q Shape: torch.Size([256, 200])
Sim_k Shape: torch.Size([256, 200])
After normalizing centroids: torch.Size([200, 128])


Epoch 0:   1%|          | 2/390 [00:56<2:35:25, 24.04s/batch]

Output: torch.Size([256, 513])
target: torch.Size([256])
student probs shape:torch.Size([256, 128])
teacher_probs shape:torch.Size([256, 128])
kd_loss: 0.4587862491607666
Sim_q Shape: torch.Size([256, 200])
Sim_k Shape: torch.Size([256, 200])
After normalizing centroids: torch.Size([200, 128])


Epoch 0:   1%|          | 3/390 [01:00<1:35:25, 14.79s/batch]

Output: torch.Size([256, 513])
target: torch.Size([256])
student probs shape:torch.Size([256, 128])
teacher_probs shape:torch.Size([256, 128])
kd_loss: 0.431710422039032
Sim_q Shape: torch.Size([256, 200])
Sim_k Shape: torch.Size([256, 200])
After normalizing centroids: torch.Size([200, 128])


Epoch 0:   1%|          | 4/390 [01:05<1:11:08, 11.06s/batch]

Output: torch.Size([256, 513])
target: torch.Size([256])
student probs shape:torch.Size([256, 128])
teacher_probs shape:torch.Size([256, 128])
kd_loss: 0.40890052914619446
Sim_q Shape: torch.Size([256, 200])
Sim_k Shape: torch.Size([256, 200])
After normalizing centroids: torch.Size([200, 128])


Epoch 0:   1%|▏         | 5/390 [01:10<56:49,  8.86s/batch]  

Output: torch.Size([256, 513])
target: torch.Size([256])
student probs shape:torch.Size([256, 128])
teacher_probs shape:torch.Size([256, 128])
kd_loss: 0.3954560458660126
Sim_q Shape: torch.Size([256, 200])
Sim_k Shape: torch.Size([256, 200])
After normalizing centroids: torch.Size([200, 128])


Epoch 0:   2%|▏         | 6/390 [01:15<47:37,  7.44s/batch]

Output: torch.Size([256, 513])
target: torch.Size([256])
student probs shape:torch.Size([256, 128])
teacher_probs shape:torch.Size([256, 128])
kd_loss: 0.35766375064849854
Sim_q Shape: torch.Size([256, 200])
Sim_k Shape: torch.Size([256, 200])
After normalizing centroids: torch.Size([200, 128])


Epoch 0:   2%|▏         | 7/390 [01:20<42:37,  6.68s/batch]

Output: torch.Size([256, 513])
target: torch.Size([256])
student probs shape:torch.Size([256, 128])
teacher_probs shape:torch.Size([256, 128])
kd_loss: 0.35998016595840454
Sim_q Shape: torch.Size([256, 200])
Sim_k Shape: torch.Size([256, 200])
After normalizing centroids: torch.Size([200, 128])


Epoch 0:   2%|▏         | 8/390 [01:24<38:00,  5.97s/batch]

Output: torch.Size([256, 513])
target: torch.Size([256])
student probs shape:torch.Size([256, 128])
teacher_probs shape:torch.Size([256, 128])
kd_loss: 0.3854227066040039
Sim_q Shape: torch.Size([256, 200])
Sim_k Shape: torch.Size([256, 200])
After normalizing centroids: torch.Size([200, 128])


Epoch 0:   2%|▏         | 9/390 [01:29<36:01,  5.67s/batch]

Output: torch.Size([256, 513])
target: torch.Size([256])
student probs shape:torch.Size([256, 128])
teacher_probs shape:torch.Size([256, 128])
kd_loss: 0.44222337007522583
Sim_q Shape: torch.Size([256, 200])
Sim_k Shape: torch.Size([256, 200])
After normalizing centroids: torch.Size([200, 128])


Epoch 0:   3%|▎         | 10/390 [01:35<35:07,  5.55s/batch]

Output: torch.Size([256, 513])
target: torch.Size([256])
student probs shape:torch.Size([256, 128])
teacher_probs shape:torch.Size([256, 128])
kd_loss: 0.4019940197467804
Sim_q Shape: torch.Size([256, 200])
Sim_k Shape: torch.Size([256, 200])
After normalizing centroids: torch.Size([200, 128])


Epoch 0:   3%|▎         | 11/390 [01:40<35:12,  5.57s/batch]

Output: torch.Size([256, 513])
target: torch.Size([256])
student probs shape:torch.Size([256, 128])
teacher_probs shape:torch.Size([256, 128])
kd_loss: 0.4421185255050659
Sim_q Shape: torch.Size([256, 200])
Sim_k Shape: torch.Size([256, 200])
After normalizing centroids: torch.Size([200, 128])


Epoch 0:   3%|▎         | 12/390 [01:45<33:58,  5.39s/batch]

Output: torch.Size([256, 513])
target: torch.Size([256])
student probs shape:torch.Size([256, 128])
teacher_probs shape:torch.Size([256, 128])
kd_loss: 0.45385879278182983
Sim_q Shape: torch.Size([256, 200])
Sim_k Shape: torch.Size([256, 200])
After normalizing centroids: torch.Size([200, 128])


Epoch 0:   3%|▎         | 13/390 [01:51<33:44,  5.37s/batch]

Output: torch.Size([256, 513])
target: torch.Size([256])
student probs shape:torch.Size([256, 128])
teacher_probs shape:torch.Size([256, 128])
kd_loss: 0.4639788568019867
Sim_q Shape: torch.Size([256, 200])
Sim_k Shape: torch.Size([256, 200])
After normalizing centroids: torch.Size([200, 128])


Epoch 0:   4%|▎         | 14/390 [01:56<33:31,  5.35s/batch]

Output: torch.Size([256, 513])
target: torch.Size([256])
student probs shape:torch.Size([256, 128])
teacher_probs shape:torch.Size([256, 128])
kd_loss: 0.4442075490951538
Sim_q Shape: torch.Size([256, 200])
Sim_k Shape: torch.Size([256, 200])
After normalizing centroids: torch.Size([200, 128])


Epoch 0:   4%|▍         | 15/390 [02:01<33:15,  5.32s/batch]

Output: torch.Size([256, 513])
target: torch.Size([256])
student probs shape:torch.Size([256, 128])
teacher_probs shape:torch.Size([256, 128])
kd_loss: 0.4400923252105713
Sim_q Shape: torch.Size([256, 200])
Sim_k Shape: torch.Size([256, 200])
After normalizing centroids: torch.Size([200, 128])


Epoch 0:   4%|▍         | 16/390 [02:05<31:20,  5.03s/batch]

Output: torch.Size([256, 513])
target: torch.Size([256])
student probs shape:torch.Size([256, 128])
teacher_probs shape:torch.Size([256, 128])
kd_loss: 0.4327610731124878
Sim_q Shape: torch.Size([256, 200])
Sim_k Shape: torch.Size([256, 200])
After normalizing centroids: torch.Size([200, 128])


Epoch 0:   4%|▍         | 17/390 [02:11<31:29,  5.06s/batch]

Output: torch.Size([256, 513])
target: torch.Size([256])
student probs shape:torch.Size([256, 128])
teacher_probs shape:torch.Size([256, 128])
kd_loss: 0.4301336407661438
Sim_q Shape: torch.Size([256, 200])
Sim_k Shape: torch.Size([256, 200])
After normalizing centroids: torch.Size([200, 128])


Epoch 0:   5%|▍         | 18/390 [02:16<31:14,  5.04s/batch]

Output: torch.Size([256, 513])
target: torch.Size([256])
student probs shape:torch.Size([256, 128])
teacher_probs shape:torch.Size([256, 128])
kd_loss: 0.40454956889152527
Sim_q Shape: torch.Size([256, 200])
Sim_k Shape: torch.Size([256, 200])
After normalizing centroids: torch.Size([200, 128])


Epoch 0:   5%|▍         | 19/390 [02:21<32:24,  5.24s/batch]

Output: torch.Size([256, 513])
target: torch.Size([256])
student probs shape:torch.Size([256, 128])
teacher_probs shape:torch.Size([256, 128])
kd_loss: 0.4079052209854126
Sim_q Shape: torch.Size([256, 200])
Sim_k Shape: torch.Size([256, 200])
After normalizing centroids: torch.Size([200, 128])


Epoch 0:   5%|▌         | 20/390 [02:27<32:54,  5.34s/batch]

Output: torch.Size([256, 513])
target: torch.Size([256])
student probs shape:torch.Size([256, 128])
teacher_probs shape:torch.Size([256, 128])
kd_loss: 0.37608852982521057
Sim_q Shape: torch.Size([256, 200])
Sim_k Shape: torch.Size([256, 200])
After normalizing centroids: torch.Size([200, 128])


Epoch 0:   5%|▌         | 21/390 [02:31<31:19,  5.09s/batch]

Output: torch.Size([256, 513])
target: torch.Size([256])
student probs shape:torch.Size([256, 128])
teacher_probs shape:torch.Size([256, 128])
kd_loss: 0.3703717887401581
Sim_q Shape: torch.Size([256, 200])
Sim_k Shape: torch.Size([256, 200])
After normalizing centroids: torch.Size([200, 128])


Epoch 0:   6%|▌         | 22/390 [02:37<32:22,  5.28s/batch]

Output: torch.Size([256, 513])
target: torch.Size([256])
student probs shape:torch.Size([256, 128])
teacher_probs shape:torch.Size([256, 128])
kd_loss: 0.36726585030555725
Sim_q Shape: torch.Size([256, 200])
Sim_k Shape: torch.Size([256, 200])
After normalizing centroids: torch.Size([200, 128])


Epoch 0:   6%|▌         | 23/390 [02:42<32:29,  5.31s/batch]

Output: torch.Size([256, 513])
target: torch.Size([256])
student probs shape:torch.Size([256, 128])
teacher_probs shape:torch.Size([256, 128])
kd_loss: 0.39918723702430725
Sim_q Shape: torch.Size([256, 200])
Sim_k Shape: torch.Size([256, 200])
After normalizing centroids: torch.Size([200, 128])


Epoch 0:   6%|▌         | 24/390 [02:48<33:12,  5.45s/batch]

Output: torch.Size([256, 513])
target: torch.Size([256])
student probs shape:torch.Size([256, 128])
teacher_probs shape:torch.Size([256, 128])
kd_loss: 0.40758970379829407
Sim_q Shape: torch.Size([256, 200])
Sim_k Shape: torch.Size([256, 200])
After normalizing centroids: torch.Size([200, 128])


Epoch 0:   6%|▋         | 25/390 [02:55<34:50,  5.73s/batch]

Output: torch.Size([256, 513])
target: torch.Size([256])
student probs shape:torch.Size([256, 128])
teacher_probs shape:torch.Size([256, 128])
kd_loss: 0.39393407106399536
Sim_q Shape: torch.Size([256, 200])
Sim_k Shape: torch.Size([256, 200])
After normalizing centroids: torch.Size([200, 128])


Epoch 0:   7%|▋         | 26/390 [02:59<32:52,  5.42s/batch]

Output: torch.Size([256, 513])
target: torch.Size([256])
student probs shape:torch.Size([256, 128])
teacher_probs shape:torch.Size([256, 128])
kd_loss: 0.37990278005599976
Sim_q Shape: torch.Size([256, 200])
Sim_k Shape: torch.Size([256, 200])
After normalizing centroids: torch.Size([200, 128])


Epoch 0:   7%|▋         | 27/390 [03:05<32:23,  5.35s/batch]

Output: torch.Size([256, 513])
target: torch.Size([256])
student probs shape:torch.Size([256, 128])
teacher_probs shape:torch.Size([256, 128])
kd_loss: 0.38298311829566956
Sim_q Shape: torch.Size([256, 200])
Sim_k Shape: torch.Size([256, 200])
After normalizing centroids: torch.Size([200, 128])


Epoch 0:   7%|▋         | 28/390 [03:11<34:23,  5.70s/batch]

Output: torch.Size([256, 513])
target: torch.Size([256])
student probs shape:torch.Size([256, 128])
teacher_probs shape:torch.Size([256, 128])
kd_loss: 0.4255527853965759
Sim_q Shape: torch.Size([256, 200])
Sim_k Shape: torch.Size([256, 200])
After normalizing centroids: torch.Size([200, 128])


Epoch 0:   7%|▋         | 29/390 [03:16<33:26,  5.56s/batch]

Output: torch.Size([256, 513])
target: torch.Size([256])
student probs shape:torch.Size([256, 128])
teacher_probs shape:torch.Size([256, 128])
kd_loss: 0.3918708562850952
Sim_q Shape: torch.Size([256, 200])
Sim_k Shape: torch.Size([256, 200])
After normalizing centroids: torch.Size([200, 128])


Epoch 0:   8%|▊         | 30/390 [03:22<33:34,  5.59s/batch]

Output: torch.Size([256, 513])
target: torch.Size([256])
student probs shape:torch.Size([256, 128])
teacher_probs shape:torch.Size([256, 128])
kd_loss: 0.3990691304206848
Sim_q Shape: torch.Size([256, 200])
Sim_k Shape: torch.Size([256, 200])
After normalizing centroids: torch.Size([200, 128])


Epoch 0:   8%|▊         | 31/390 [03:27<33:23,  5.58s/batch]

Output: torch.Size([256, 513])
target: torch.Size([256])
student probs shape:torch.Size([256, 128])
teacher_probs shape:torch.Size([256, 128])
kd_loss: 0.3797306418418884
Sim_q Shape: torch.Size([256, 200])
Sim_k Shape: torch.Size([256, 200])
After normalizing centroids: torch.Size([200, 128])


Epoch 0:   8%|▊         | 32/390 [03:33<32:54,  5.52s/batch]

Output: torch.Size([256, 513])
target: torch.Size([256])
student probs shape:torch.Size([256, 128])
teacher_probs shape:torch.Size([256, 128])
kd_loss: 0.38976791501045227
Sim_q Shape: torch.Size([256, 200])
Sim_k Shape: torch.Size([256, 200])
After normalizing centroids: torch.Size([200, 128])


Epoch 0:   8%|▊         | 33/390 [03:38<33:02,  5.55s/batch]

Output: torch.Size([256, 513])
target: torch.Size([256])
student probs shape:torch.Size([256, 128])
teacher_probs shape:torch.Size([256, 128])
kd_loss: 0.3813732862472534
Sim_q Shape: torch.Size([256, 200])
Sim_k Shape: torch.Size([256, 200])
After normalizing centroids: torch.Size([200, 128])


Epoch 0:   9%|▊         | 34/390 [03:44<33:44,  5.69s/batch]

Output: torch.Size([256, 513])
target: torch.Size([256])
student probs shape:torch.Size([256, 128])
teacher_probs shape:torch.Size([256, 128])
kd_loss: 0.4056026339530945
Sim_q Shape: torch.Size([256, 200])
Sim_k Shape: torch.Size([256, 200])
After normalizing centroids: torch.Size([200, 128])


Epoch 0:   9%|▉         | 35/390 [03:49<32:10,  5.44s/batch]

Output: torch.Size([256, 513])
target: torch.Size([256])
student probs shape:torch.Size([256, 128])
teacher_probs shape:torch.Size([256, 128])
kd_loss: 0.37723106145858765
Sim_q Shape: torch.Size([256, 200])
Sim_k Shape: torch.Size([256, 200])
After normalizing centroids: torch.Size([200, 128])


Epoch 0:   9%|▉         | 36/390 [03:55<31:47,  5.39s/batch]

Output: torch.Size([256, 513])
target: torch.Size([256])
student probs shape:torch.Size([256, 128])
teacher_probs shape:torch.Size([256, 128])
kd_loss: 0.3858591914176941
Sim_q Shape: torch.Size([256, 200])
Sim_k Shape: torch.Size([256, 200])
After normalizing centroids: torch.Size([200, 128])


Epoch 0:   9%|▉         | 37/390 [04:01<32:41,  5.56s/batch]

Output: torch.Size([256, 513])
target: torch.Size([256])
student probs shape:torch.Size([256, 128])
teacher_probs shape:torch.Size([256, 128])
kd_loss: 0.37976494431495667
Sim_q Shape: torch.Size([256, 200])
Sim_k Shape: torch.Size([256, 200])
After normalizing centroids: torch.Size([200, 128])


Epoch 0:  10%|▉         | 38/390 [04:07<33:33,  5.72s/batch]

Output: torch.Size([256, 513])
target: torch.Size([256])
student probs shape:torch.Size([256, 128])
teacher_probs shape:torch.Size([256, 128])
kd_loss: 0.4020177125930786
Sim_q Shape: torch.Size([256, 200])
Sim_k Shape: torch.Size([256, 200])
After normalizing centroids: torch.Size([200, 128])


Epoch 0:  10%|█         | 39/390 [04:13<33:41,  5.76s/batch]

Output: torch.Size([256, 513])
target: torch.Size([256])
student probs shape:torch.Size([256, 128])
teacher_probs shape:torch.Size([256, 128])
kd_loss: 0.3792133927345276
Sim_q Shape: torch.Size([256, 200])
Sim_k Shape: torch.Size([256, 200])
After normalizing centroids: torch.Size([200, 128])


Epoch 0:  10%|█         | 40/390 [04:18<33:03,  5.67s/batch]

Output: torch.Size([256, 513])
target: torch.Size([256])
student probs shape:torch.Size([256, 128])
teacher_probs shape:torch.Size([256, 128])
kd_loss: 0.3900160789489746
Sim_q Shape: torch.Size([256, 200])
Sim_k Shape: torch.Size([256, 200])
After normalizing centroids: torch.Size([200, 128])


Epoch 0:  11%|█         | 41/390 [04:23<31:33,  5.43s/batch]

Output: torch.Size([256, 513])
target: torch.Size([256])
student probs shape:torch.Size([256, 128])
teacher_probs shape:torch.Size([256, 128])
kd_loss: 0.3946141004562378
Sim_q Shape: torch.Size([256, 200])
Sim_k Shape: torch.Size([256, 200])
After normalizing centroids: torch.Size([200, 128])


Epoch 0:  11%|█         | 42/390 [04:28<30:47,  5.31s/batch]

Output: torch.Size([256, 513])
target: torch.Size([256])
student probs shape:torch.Size([256, 128])
teacher_probs shape:torch.Size([256, 128])
kd_loss: 0.36189591884613037
Sim_q Shape: torch.Size([256, 200])
Sim_k Shape: torch.Size([256, 200])
After normalizing centroids: torch.Size([200, 128])


Epoch 0:  11%|█         | 43/390 [04:32<29:21,  5.08s/batch]

Output: torch.Size([256, 513])
target: torch.Size([256])
student probs shape:torch.Size([256, 128])
teacher_probs shape:torch.Size([256, 128])
kd_loss: 0.3862631916999817
Sim_q Shape: torch.Size([256, 200])
Sim_k Shape: torch.Size([256, 200])
After normalizing centroids: torch.Size([200, 128])


Epoch 0:  11%|█▏        | 44/390 [04:37<28:32,  4.95s/batch]

Output: torch.Size([256, 513])
target: torch.Size([256])
student probs shape:torch.Size([256, 128])
teacher_probs shape:torch.Size([256, 128])
kd_loss: 0.3800046443939209
Sim_q Shape: torch.Size([256, 200])
Sim_k Shape: torch.Size([256, 200])
After normalizing centroids: torch.Size([200, 128])


Epoch 0:  12%|█▏        | 45/390 [04:47<36:44,  6.39s/batch]


KeyboardInterrupt: 