# 1. Data/Model settings

In [1]:
from visualizer import *
from visualizer_supcon import *

  from .autonotebook import tqdm as notebook_tqdm
  @numba.jit()
  @numba.jit()
  @numba.jit()
  @numba.jit()


In [2]:
# Create the run-option object; `parse_option` is not defined in this notebook (presumably it comes
# from the star-imported visualizer modules). The cells below override individual fields of `opt`.
opt = parse_option()

In [3]:
opt.dataset = 'waterbirds'
# opt.tl_method = "linear_probing"  # for zero-shot
opt.tl_method = "contrastive_adapter"
opt.train_target = "class"
non_target = "spurious"

# Every path below is derived from one project data root so that moving the data (or switching
# datasets) only requires editing these two lines.
# NOTE: machine-specific absolute path -- adjust it on another machine.
PROJECT_DATA_ROOT = "/home/jinsu/workstation/project/debiasing-multi-modal/data"
EMBEDDING_ROOT = f"{PROJECT_DATA_ROOT}/embeddings_unnormalized/{opt.dataset}"

opt.text_embedding_dir = f"{EMBEDDING_ROOT}/clip_{opt.train_target}.json"
opt.text_spurious_embedding_dir = f"{EMBEDDING_ROOT}/clip_{non_target}.json"
opt.text_group_embedding_dir = f"{EMBEDDING_ROOT}/clip_group.json"
# Image embeddings live in the RN50 sub-folder (presumably CLIP ResNet-50 features); the dataset
# segment now follows opt.dataset like the text-embedding paths instead of a hard-coded "waterbirds".
opt.image_embedding_dir = f"{EMBEDDING_ROOT}/RN50/clip.json"
opt.data_dir = f"{PROJECT_DATA_ROOT}/waterbirds/waterbird_complete95_forest2water2"

In [4]:
# Unpack everything `initialize_for_vis` builds for the configured run: the train set, the
# train/val/test loaders, the y/p helper, the train group ratio, the classifier and the criterion.
(trainset, train_loader, val_loader, test_loader, 
 get_yp_func, train_group_ratio, classifier, criterion) = initialize_for_vis(opt)

ce_loss = criterion  # alias: the same criterion object, also referred to as `ce_loss`

> Start Transfer Learning using [contrastive_adapter]
Load image embedding of Waterbirds: /home/jinsu/workstation/project/debiasing-multi-modal/data/embeddings_unnormalized/waterbirds/RN50/clip.json
/home/jinsu/workstation/project/debiasing-multi-modal/data/embeddings_unnormalized/waterbirds/RN50/clip.json
ㄴ Corresponding text embedding of Waterbirds: /home/jinsu/workstation/project/debiasing-multi-modal/data/embeddings_unnormalized/waterbirds/clip_class.json
Load Data Loader (train, validation, test)
/home/jinsu/workstation/project/debiasing-multi-modal/data/embeddings_unnormalized/waterbirds/RN50/clip.json
/home/jinsu/workstation/project/debiasing-multi-modal/data/embeddings_unnormalized/waterbirds/RN50/clip.json
/home/jinsu/workstation/project/debiasing-multi-modal/data/embeddings_unnormalized/waterbirds/RN50/clip.json
Training target : class (Land bird(0) / Water bird(1))
Off-the-shelf classifier : Contrastive Adapter


# 2. Training & Visualization

## 2.1. Training Settings

In [5]:
# Training options for the contrastive-adapter (CA) run; each line overrides one field of `opt`.
opt.balance_by_zs_pred = False  # Balance the anchors' classes.
opt.re_shuffle_ca_loader = True # Reshuffle the contrastive batches every epoch.
opt.maintain_alternative_ordering = True # Assign anchor batches alternating between classes (used together with opt.balance_by_zs_pred).

opt.correct_class_bias = False # Correct the class imbalance made worse by the CA paper's CE-loader p-sampling scheme.
opt.reweighting_by_class = False # Give the CE loader a 1:1 class balance.

# --- Optimisation ---
opt.epochs = 5
opt.learning_rate = 1e-3
opt.batch_size = 128
opt.num_anchor = 1
opt.num_positive = 2048
opt.num_negative = 2048

opt.print_freq_ca = 1
opt.print_freq = 1
opt.batch_factor = 32 # Batch factor inside the CA loader (1 update per 32 anchors).
opt.contrastive_weight = 1.0

# NOTE: ca_update / ce_update are overridden again in the visualization-settings cell below.
opt.ca_update = 10000 # Number of batches after which CA updates stop (1~50).
opt.ce_update = 10000 # Number of batches after which CE updates stop (1~12).
opt.ca_pre_norm = True # CLIP -> "Normalized CLIP" -> Adapter -> ...

## 2.2. Visualization Settings

In [12]:
from queue import Queue  # NOTE(review): not used in any visible cell
opt.ca_update = 1 # Number of CA updates (overrides the training-settings value)
opt.ce_update = 0 # Number of CE updates (overrides the training-settings value)

# Number of embedding snapshots kept in each queue (the example shows two snapshots per epoch).
opt.max_length_ebd_queue = 4 # (e.g., at 20 epoch, we are having following embeddings: 18-after-CA, 18-after-CE, 19-after-CA, 19-after-CE)

vis_handler = VisHandler(opt)

vis_handler.SaveTextEmbeddings(opt.text_embedding_dir) # class text-embedding file (unnormalized!) (c.f. clip_inference_including_group_with_unnorm.py)
vis_handler.SaveTextEmbeddings(opt.text_spurious_embedding_dir) # spurious text-embedding file (unnormalized!) (c.f. clip_inference_including_group_with_unnorm.py)
# vis_handler.SaveTextEmbeddings(opt.text_group_embedding_dir) # group text-embedding file (unnormalized!) (c.f. clip_inference_including_group_with_unnorm.py)

num_nn_text_ebd = 10 # For each text embedding, the [num_nn_text_ebd] closest image embeddings are averaged and that average is used to place the text embedding in the plot.
set_bbox=False # True: labels are somewhat more readable, but the boxes sometimes hide points.

In [7]:
def append_like_queue(array, new_instance, max_length):
    """Append ``new_instance`` to ``array`` in place, keeping at most ``max_length`` items.

    The oldest items (front of the list) are evicted first, so the list behaves like a bounded
    FIFO queue. Plain lists are used instead of ``collections.deque`` because later cells slice
    and negatively index the queues.

    Args:
        array: list mutated in place.
        new_instance: item appended at the end.
        max_length: maximum number of items kept after the append; must be >= 1.

    Raises:
        ValueError: if ``max_length`` < 1. (The previous loop raised an opaque IndexError
            from ``del array[0]`` on an empty list in that case.)
    """
    if max_length < 1:
        raise ValueError(f"max_length must be >= 1, got {max_length}")
    # Make room for exactly one new item in a single slice deletion.
    overflow = len(array) - max_length + 1
    if overflow > 0:
        del array[:overflow]
    array.append(new_instance)


# Per-split embedding queues. Each entry is an ``(embeddings, meta_results)`` snapshot (that is
# how plot_umap_for_ca unpacks them), presumably appended with append_like_queue using
# max_length=opt.max_length_ebd_queue -- the filling code is not in this section.
train_ebd_q = []
val_ebd_q = []
test_ebd_q = []

In [109]:
class VisHandler:
    """Collect embeddings from embedding-based loaders and visualize them.

    Notes from the original author (translated):
    - Takes embedding-based loaders and runs zero-shot prediction, embedding reduction, etc.
      - The adapter checkpoint (.pth) saved after training can be reloaded to extract
        embeddings the same way.
      - Linear-probing / adapter training runs write their results under ``results/...``.

    The evaluation/plotting helpers used here (``validate_zs``, ``plot_umap``, ``plot_umap_all``,
    ``get_text_embedding``, ``find_closest_sample``, ``umap``, ``MDS``, ``plt``, ``tqdm``) are
    expected in the notebook namespace (presumably via the star imports); ``GetGroupWiseStatEbd``
    is defined in a separate notebook cell.
    """

    # Column layout of the per-split statistics tables drawn under each scatter plot.
    _TABLE_COLUMNS = ["Avg.", "Worst", "group0", "group1", "group2", "group3"]
    # Fixed colour palettes for binary labels and for the four groups.
    _DISCRETE_PALETTES = {
        2: ['midnightblue', 'red'],
        4: ['midnightblue', 'darkorange', 'red', 'royalblue'],
    }
    _SPLITS = ("train", "val", "test")

    def __init__(self, args):
        """Initialise from the arguments of a single run.

        Args:
            args: run options; ``args.dataset`` selects the legend labels and
                ``args.tl_method`` is read when building plot titles.
        """
        self.args = args
        self.device = torch.device('cuda:0') if torch.cuda.is_available() else torch.device('cpu')

        self.final_results = {}  # Best train / val / test results

        # Defined for every dataset so unsupported datasets fail with a clear KeyError at plot time.
        self.legend_labels_dict = {}
        if self.args.dataset == "waterbirds":
            self.legend_labels_dict = {"target": {0: "Landbird", 1: "Waterbird"},
                                       "spurious": {0: "Land-background", 1: "Water-background"},
                                       "group": {0: "Landbird on Land-background", 1: "Landbird on Water-background",
                                                 2: "Waterbird on Land-background", 3: "Waterbird on Water-background"},
                                       "prediction": {0: "Pred. to Landbird", 1: "Pred. to Waterbird"}}

        self.model = None
        self.text_embeddings = []
        self.group_wise_stat_ebd = {}
        self.group_wise_stat_conf = {}

        self.epoch = 0
        # Epoch shown in plot titles of non-zero-shot methods (VisRep / VisRepAll read it);
        # callers may overwrite it. Previously it was never initialised (AttributeError).
        self.model_epoch = 0

    def SaveWaterbirdsDatasets(self, trainset):
        """Store the training dataset."""
        self.train_set = trainset

    def SaveWaterbirdsLoaders(self, train_loader, val_loader, test_loader):
        """Store the train / val / test loaders."""
        self.train_loader = train_loader
        self.val_loader = val_loader
        self.test_loader = test_loader

    def SaveTextEmbeddings(self, embedding_dir):
        """Append the text embeddings loaded from ``embedding_dir`` to ``self.text_embeddings``.

        The plotting code expects each item to be a ``{prompt_template: feature}`` dict
        (assumed to be what ``get_text_embedding(..., return_key=True)`` yields -- confirm).
        """
        self.text_embeddings.extend(get_text_embedding(embedding_dir, return_key=True))

    def SaveModel(self, classifier):
        """Store the classifier used for zero-shot evaluation."""
        self.classifier = classifier

    def SaveUtils(self, criterion, get_yp_func, train_group_ratio):
        """Store the loss and helpers needed by ``validate_zs``."""
        self.criterion = criterion
        self.get_yp_func = get_yp_func
        self.train_group_ratio = train_group_ratio

    def SaveZeroShotResults(self, train_loader, val_loader, test_loader):
        """Evaluate the stored classifier on every split and keep its group accuracies in ``self.zs_results``."""
        self.zs_results = {}
        for split, loader in (("train", train_loader), ("val", val_loader), ("test", test_loader)):
            _, _, group_acc = validate_zs(self.args, loader, self.classifier, self.criterion, self.get_yp_func,
                                          self.train_group_ratio, target="class",
                                          print_label=f'Get ZS Acc. of {split} (class)')
            self.zs_results[split] = group_acc

    def GetEmbeddings(self, dataloader):
        """Collect the embeddings and per-sample labels yielded by ``dataloader``.

        Each batch is unpacked as ``(embeddings, labels_dict, _)``; ``labels_dict`` must hold
        tensors under "class", "group", "spurious" and "ebd_y_pred".

        NOTE: extraction through a trained adapter is not implemented yet; the embeddings are
        taken from the loader as-is.

        Returns:
            ``(embeddings, meta)``: ``embeddings`` is a numpy array stacked over samples and
            ``meta`` maps "targets", "spuriouss", "groups" and "predictions" to lists.
        """
        total_embeddings = []
        total_labels, total_spuriouss, total_groups, total_predictions = [], [], [], []

        print('> Saving activations')
        with torch.no_grad():
            for embeddings, labels_dict, _ in tqdm(dataloader, desc='Running inference'):
                total_labels.extend(labels_dict["class"].numpy())
                total_groups.extend(labels_dict["group"].numpy())
                total_spuriouss.extend(labels_dict["spurious"].numpy())
                total_predictions.extend(labels_dict["ebd_y_pred"].numpy())
                total_embeddings.extend(embeddings.numpy())

        total_meta_results = {"targets": total_labels, "spuriouss": total_spuriouss, "groups": total_groups,
                              "predictions": total_predictions}
        return np.array(total_embeddings), total_meta_results

    @classmethod
    def _stat_row(cls, stat_by_key):
        """Turn ``{'full': v, g0: v0, ...}`` into ``[v, 0, v0, ...]``; the 0 fills the 'Worst' column."""
        values = list(stat_by_key.values())
        return [values[0], 0] + values[1:]

    @classmethod
    def _build_stat_table(cls, acc_by_key, norm_by_key, dist_by_key=None):
        """Build one split's statistics table (accuracy, optional diversity, centroid norm), rounded to 3 places.

        The last entry of ``acc_by_key`` is dropped, as the original code did.
        """
        rows, index = [list(acc_by_key.values())[:-1]], ["Acc."]
        if dist_by_key is not None:
            rows.append(cls._stat_row(dist_by_key))
            index.append("Div.")
        rows.append(cls._stat_row(norm_by_key))
        index.append("Centr. Norm.")
        return pd.DataFrame(rows, index=index, columns=cls._TABLE_COLUMNS).round(3)

    @classmethod
    def _group_mean_payload(cls, group_wise_stat_ebd):
        """Pack per-split mean vectors and their keys as (means_tr, means_val, means_test, keys_tr, keys_val, keys_test)."""
        mean_vectors = [list(group_wise_stat_ebd[split]["mean_vector"].values()) for split in cls._SPLITS]
        mean_keys = [list(group_wise_stat_ebd[split]["mean_vector"].keys()) for split in cls._SPLITS]
        return tuple(mean_vectors) + tuple(mean_keys)

    @staticmethod
    def _stack_last_two(queue, label_key):
        """Concatenate the labels (``meta[label_key]``) and embeddings of the last two queue entries along the sample axis."""
        if len(queue) < 2:
            raise ValueError(f"Expected at least 2 snapshots in the embedding queue, got {len(queue)}")
        entries = queue[-2:]
        labels = np.concatenate([np.asarray(entry[1][label_key]) for entry in entries])
        embeddings = np.concatenate([np.asarray(entry[0]) for entry in entries], axis=0)
        return labels, embeddings

    @classmethod
    def _palette(cls, num_classes):
        """Colour list for ``num_classes`` discrete labels (fixed palettes for 2 / 4 classes, tab10 otherwise)."""
        if num_classes in cls._DISCRETE_PALETTES:
            return cls._DISCRETE_PALETTES[num_classes]
        return [plt.cm.tab10(i % 10) for i in range(num_classes)]

    def VisRep(self, model, dataloader, vis_on, label_types=['group', 'target', 'spurious', 'prediction'], num_data=None, reduced_dim=2,
                          figsize=(8, 6), save=True, ftype='.png', title_suffix=None, save_id_suffix=None,
                          annotate_points=None, plot_mds=False, seed=42):
        """Plot one split's embeddings coloured by each label type.

        Args:
            model: currently unused -- embeddings are taken from ``dataloader`` directly.
            dataloader: embedding loader to visualize.
            vis_on: name of the split, used in titles and file names
                (e.g. "train", "val", "test", "val_fg", "test_fg").
            label_types: subset of ['confidence', 'target', 'spurious', 'group', 'prediction'].
            num_data: forwarded to ``plot_umap`` (None = all data).
            reduced_dim: projection dimension.
            title_suffix: plot-title suffix; a default describing method/split/epoch is used when None.
            plot_mds: additionally plot an MDS projection.
        """
        total_embeddings, total_meta_results = self.GetEmbeddings(dataloader)

        if title_suffix is None:
            if self.args.tl_method == "linear_probing":
                title_suffix = f'([CLIP ZS] Rep. on [{vis_on}])'
            else:
                title_suffix = f'([{self.args.tl_method}] Rep. on [{vis_on}] (Epoch {self.model_epoch}))'

        print(f'total_embeddings.shape: {total_embeddings.shape}')
        methods = ['umap', 'mds'] if plot_mds else ['umap']
        for label_type in tqdm(label_types):
            save_id = f'{reduced_dim}d_{label_type}_{vis_on}'
            if save_id_suffix is not None:
                save_id = f'{save_id}_{save_id_suffix}'
            for method in methods:
                plot_umap(total_embeddings, total_meta_results, label_type, self.legend_labels_dict, reduced_dim, num_data, method=method,
                          offset=0, figsize=figsize, save_id=save_id, save=save,
                          ftype=ftype, title_suffix=title_suffix, annotate_points=annotate_points,
                          seed=seed, display_image=True)

    def VisRepAll(self, train_loader, val_loader, test_loader, label_types=['group', 'target', 'spurious', 'prediction'], num_data=None, reduced_dim=2,
                          figsize=(24, 6), save=True, ftype='.png', title_suffix=None, save_id_suffix=None,
                          annotate_points=None, plot_mds=False, seed=42, text_ebd=None, group_mean_ebd=None, num_nn_text_ebd=10, set_bbox=False):
        """Project train/val/test into one shared sub-space (same UMAP structure) and plot them side by side.

        Requires ``SaveZeroShotResults`` to have been called (the tables read ``self.zs_results``).
        Group-wise statistics are stored in ``self.group_wise_stat_ebd``. ``group_mean_ebd`` is a
        flag: when truthy, the per-split mean vectors are computed here and passed to the plot.
        ``num_data`` is accepted for interface compatibility but not used.
        """
        loaders = {"train": train_loader, "val": val_loader, "test": test_loader}
        embeddings, metas = {}, {}
        for split in self._SPLITS:
            embeddings[split], metas[split] = self.GetEmbeddings(loaders[split])

        # Group-wise statistics: (norm of mean vector, mean vector) / compactness for each split.
        print("> Calculating [Group-wise] Statistics...")
        for split in self._SPLITS:
            self.group_wise_stat_ebd[split] = GetGroupWiseStatEbd(embeddings[split], np.array(metas[split]["groups"]))

        dfs = [self._build_stat_table(self.zs_results[split],
                                      self.group_wise_stat_ebd[split]["mean_vector_norm"],
                                      self.group_wise_stat_ebd[split]["pairwise_distance"])
               for split in self._SPLITS]

        group_mean_ebd = self._group_mean_payload(self.group_wise_stat_ebd) if group_mean_ebd else None

        if title_suffix is None:
            if self.args.tl_method == "linear_probing":
                title_suffix = f'([CLIP ZS] Representation ({num_nn_text_ebd} near.))'
            else:
                title_suffix = f'([{self.args.tl_method}] Representation ({num_nn_text_ebd} near.) (Epoch {self.model_epoch}))'

        split_embeddings = [embeddings[split] for split in self._SPLITS]
        split_metas = [metas[split] for split in self._SPLITS]
        methods = ['umap', 'mds'] if plot_mds else ['umap']
        for label_type in tqdm(label_types):
            save_id = f'{reduced_dim}d_{label_type}'
            if save_id_suffix is not None:
                save_id = f'{save_id}_{save_id_suffix}'
            for method in methods:
                plot_umap_all(*split_embeddings, *split_metas,
                              label_type, self.legend_labels_dict, dfs, reduced_dim, method=method, figsize=figsize, save_id=save_id,
                              save=save, ftype=ftype, title_suffix=title_suffix, annotate_points=annotate_points, seed=seed,
                              display_image=True, text_ebd=text_ebd, group_mean_ebd=group_mean_ebd,
                              num_nn_text_ebd=num_nn_text_ebd, set_bbox=set_bbox)

    def plot_umap_for_ca(self, ebd_queue_train, ebd_queue_val, ebd_queue_test, label_type, save_root, save_id, legend_labels_dict, reduced_dim=2, method='umap', figsize=(24, 6), save=True,
              ftype='.png', title_suffix=None, seed=42, display_image=True, text_ebd = True, group_mean_ebd = True, num_nn_text_ebd = 10, remove_prefix=True, set_bbox=False):
        """Jointly project the two most recent snapshots of every split and plot them in a 2x3 grid.

        The last two ``(embeddings, meta_results)`` entries of each queue are projected together
        with UMAP (or MDS), optionally along with text-prompt and group-mean embeddings. The top row
        shows one scatter plot per split; the bottom row shows statistics of the latest snapshot.

        Args:
            ebd_queue_train / ebd_queue_val / ebd_queue_test: lists of ``(embeddings, meta_results)``
                with at least 2 entries; ``meta_results`` must hold "groups", "group_accs" and
                ``label_type + "s"``.
            label_type: label used for colouring ('target', 'spurious', 'group', 'prediction', 'confidence').
            save_root: directory the figure is written to (created if missing).
            save_id: file stem of the saved figure.
            legend_labels_dict: ``{label_type: {label_value: legend_text}}``.
            reduced_dim: must be 2 (only 2-D scatter plots are drawn).
            method: 'umap' or 'mds'.
            seed: random state of the projection (previously ignored in favour of a fixed 42).
            text_ebd: True uses ``self.text_embeddings``; a list of ``{template: feature}`` dicts is
                used as-is; None / False / empty skips text prompts.
            group_mean_ebd: truthy to annotate the per-group mean embeddings of the latest snapshot.
            num_nn_text_ebd: each prompt is drawn at the mean of its this-many nearest image embeddings.
            remove_prefix: strip the leading "a photo of " from prompt annotations.
            set_bbox: draw a box behind prompt annotations.

        Returns:
            ``(labels_train, labels_val, labels_test, embeddings_train, embeddings_val, embeddings_test)``:
            labels and raw embeddings of the last two snapshots, concatenated along the sample axis.

        Raises:
            ValueError: if ``reduced_dim`` != 2 or a queue holds fewer than two snapshots.
        """
        if reduced_dim != 2:
            raise ValueError(f"plot_umap_for_ca only supports reduced_dim=2, got {reduced_dim}")

        queues = {"train": ebd_queue_train, "val": ebd_queue_val, "test": ebd_queue_test}

        labels_by_split, embeddings_by_split = {}, {}
        for split in self._SPLITS:
            labels_by_split[split], embeddings_by_split[split] = self._stack_last_two(queues[split], label_type + 's')

        # Tables and group means describe the latest snapshot of each split.
        print("> Calculating [Group-wise] Statistics...")
        group_wise_stat_ebd, dfs = {}, []
        for split in self._SPLITS:
            latest_embeddings, latest_meta = queues[split][-1]
            group_wise_stat_ebd[split] = GetGroupWiseStatEbd(latest_embeddings, np.array(latest_meta["groups"]), return_dist=False)
            # Each split's table uses that split's own statistics (val/test were previously swapped).
            dfs.append(self._build_stat_table(latest_meta["group_accs"], group_wise_stat_ebd[split]["mean_vector_norm"]))
        group_means = self._group_mean_payload(group_wise_stat_ebd) if group_mean_ebd else None

        n_train, n_val, n_test = (len(labels_by_split[split]) for split in self._SPLITS)
        print(f"Number of samples for Visualization : tr : [{n_train}], val : [{n_val}], test : [{n_test}]")
        total_embeddings = np.concatenate([embeddings_by_split[split] for split in self._SPLITS], axis=0)
        print("ㄴ total embedddings  : ", total_embeddings.shape)

        if text_ebd is True:
            text_ebd = self.text_embeddings
        text_templates = []
        if text_ebd:
            print("Add [text] embedding to umap-pool")
            text_templates = [list(pair.keys())[0] for pair in text_ebd]
            text_features = [list(pair.values())[0] for pair in text_ebd]
            print(f"> Calculate {num_nn_text_ebd} Nearest samples for visualization of [text prompts]")
            # Raw text embeddings sit far from the image embeddings, so each prompt is represented by
            # the mean of its nearest image embeddings (searched among image embeddings only).
            nearest_averaged = [total_embeddings[find_closest_sample(total_embeddings, feature, top_k=num_nn_text_ebd)].mean(axis=0)
                                for feature in text_features]
            total_embeddings = np.concatenate([total_embeddings, np.array(nearest_averaged)], axis=0)
            print("ㄴ total embedddings  : ", total_embeddings.shape)

        n_group_means = (0, 0, 0)
        group_keys = {split: [] for split in self._SPLITS}
        if group_means is not None:
            print("Add [group] (mean) embedding to umap-pool")
            mean_lists = group_means[:3]
            n_group_means = tuple(len(means) for means in mean_lists)
            group_keys = dict(zip(self._SPLITS, group_means[3:]))
            stacked_means = np.concatenate([np.asarray(means) for means in mean_lists], axis=0)
            total_embeddings = np.concatenate([total_embeddings, stacked_means], axis=0)
            print("ㄴ total embedddings  : ", total_embeddings.shape)

        print(f"> Projection all the embeddings to [{total_embeddings.shape[1]}d l2-norm sphere]")
        total_embeddings = total_embeddings / np.linalg.norm(total_embeddings, axis=1, keepdims=True)
        print(f"> Start Umap fitting.... (# of samples {total_embeddings.shape[0]})(dim {total_embeddings.shape[1]})")
        if method == 'umap':
            standard_embedding = umap.UMAP(random_state=seed, n_components=reduced_dim).fit_transform(total_embeddings)
        else:  # method == 'mds'
            standard_embedding = MDS(n_components=reduced_dim, random_state=seed).fit_transform(total_embeddings)

        # Slice the projection back into: split points | text prompts | group means (per split).
        projected, offset = {}, 0
        for split, count in zip(self._SPLITS, (n_train, n_val, n_test)):
            projected[split] = standard_embedding[offset: offset + count]
            offset += count
        standard_text_ebd = standard_embedding[offset: offset + len(text_templates)]
        offset += len(text_templates)
        projected_group_means = {}
        for split, count in zip(self._SPLITS, n_group_means):
            projected_group_means[split] = standard_embedding[offset: offset + count]
            offset += count

        # Annotation arrows start from the centre of the whole projection.
        standard_origin_ebd = standard_embedding.mean(axis=0)

        fig, axs = plt.subplots(2, 3, figsize=figsize, gridspec_kw={'height_ratios': [2.5, 1]})
        titles = {"train": "Train set", "val": "Val set", "test": "Test set"}
        for col, split in enumerate(self._SPLITS):
            scatter_ax, table_ax = axs[0][col], axs[1][col]
            points = projected[split]
            split_labels = labels_by_split[split]

            if label_type == 'confidence':  # continuous colouring, no class legend
                scatter_ax.scatter(points[:, 0], points[:, 1], c=np.asarray(split_labels), s=1.0, alpha=1, cmap='coolwarm')
            else:
                class_names = legend_labels_dict[label_type]
                palette = self._palette(len(class_names))
                scatter_ax.scatter(points[:, 0], points[:, 1], c=[palette[int(value)] for value in split_labels], s=1.0, alpha=1)
                legend_handles = [plt.Line2D([0], [0], marker='o', color='w', markerfacecolor=color, markersize=10) for color in palette]
                scatter_ax.legend(legend_handles, [class_names[k] for k in range(len(palette))])

            bbox_kwargs = {"bbox": dict(boxstyle="round4", fc="w")} if set_bbox else {}
            for template, ebd in zip(text_templates, standard_text_ebd):
                prompt = template.split("a photo of ")[-1] if remove_prefix else template
                scatter_ax.annotate(f'"{prompt}"', xytext=ebd, xy=standard_origin_ebd, arrowprops=dict(arrowstyle="<|-"), **bbox_kwargs)

            # Entry 0 is the mean over the whole split; only the per-group means are annotated.
            for key, ebd in list(zip(group_keys[split], projected_group_means[split]))[1:]:
                scatter_ax.annotate(f"Group {key}", xytext=ebd, xy=standard_origin_ebd, arrowprops=dict(arrowstyle="<-"))
            scatter_ax.set_title(titles[split])

            table_ax.axis('tight')
            table_ax.axis('off')
            split_df = dfs[col]
            table = table_ax.table(cellText=split_df.values, colLabels=split_df.columns, rowLabels=split_df.index, loc='center')
            table.scale(1, 2)  # stretch rows vertically so the table stays readable

        suffix = '' if title_suffix is None else f' {title_suffix}'
        fig.suptitle(f'Color by [{label_type}] labels{suffix}', size=20)

        if save:
            os.makedirs(save_root, exist_ok=True)
            fpath = os.path.join(save_root, f'{save_id}{ftype}')
            fig.savefig(fpath, dpi=300, bbox_inches="tight")
            print(f'Saved {method} to {fpath}!')

        if display_image:
            plt.show()
        plt.close(fig)

        return (labels_by_split["train"], labels_by_split["val"], labels_by_split["test"],
                embeddings_by_split["train"], embeddings_by_split["val"], embeddings_by_split["test"])
        
    


In [64]:
# Trim each embedding queue to the configured queue length (previously a hard-coded 4 that
# duplicated opt.max_length_ebd_queue).
# NOTE(review): `[:N]` keeps the OLDEST N snapshots, whereas append_like_queue evicts from the
# front and plot_umap_for_ca reads entries -2/-1 as the newest -- confirm whether `[-N:]` was intended.
train_ebd_q = train_ebd_q[:opt.max_length_ebd_queue]
val_ebd_q = val_ebd_q[:opt.max_length_ebd_queue]
test_ebd_q = test_ebd_q[:opt.max_length_ebd_queue]



In [111]:
# Re-create the handler with the same setup as the earlier visualization-settings cell
# (presumably to pick up the VisHandler class as redefined above).
vis_handler = VisHandler(opt)

vis_handler.SaveTextEmbeddings(opt.text_embedding_dir) # class text-embedding file (unnormalized!) (c.f. clip_inference_including_group_with_unnorm.py)
vis_handler.SaveTextEmbeddings(opt.text_spurious_embedding_dir) # spurious text-embedding file (unnormalized!) (c.f. clip_inference_including_group_with_unnorm.py)
# vis_handler.SaveTextEmbeddings(opt.text_group_embedding_dir) # group text-embedding file (unnormalized!) (c.f. clip_inference_including_group_with_unnorm.py)

num_nn_text_ebd = 10 # For each text embedding, the [num_nn_text_ebd] closest image embeddings are averaged and that average is used to place the text embedding in the plot.
set_bbox=False # True: labels are somewhat more readable, but the boxes sometimes hide points.

In [112]:
run_name = "임의"  # placeholder run name ("임의" = "arbitrary"), used as the figure sub-directory
fig_save_root = os.path.join("figure", run_name)

# `epoch` is not defined in any earlier cell of this notebook (it presumably leaked from a training
# loop run elsewhere); fall back to the handler's counter so the cell also runs on a fresh kernel.
epoch = globals().get("epoch", vis_handler.epoch)
save_id = f"class_{epoch}ep_ca"
title_suffix = f"After [Contrastive] (Epoch: {epoch})"

# Pass the loaded text embeddings explicitly rather than the flag `True`.
train_l, val_l, test_l, train_ebd, val_ebd, test_ebd = vis_handler.plot_umap_for_ca(
    train_ebd_q, val_ebd_q, test_ebd_q, 'target', fig_save_root, save_id,
    vis_handler.legend_labels_dict, title_suffix=title_suffix,
    text_ebd=vis_handler.text_embeddings, num_nn_text_ebd=num_nn_text_ebd, set_bbox=set_bbox)

> Calculating [Group-wise] Statistics...


In [133]:
# Flatten the arrays returned by plot_umap_for_ca into one row per sample, then pool all splits.
train_l, val_l, test_l = (split_labels.flatten() for split_labels in (train_l, val_l, test_l))
train_ebd, val_ebd, test_ebd = (split_ebd.reshape(-1, split_ebd.shape[-1]) for split_ebd in (train_ebd, val_ebd, test_ebd))

print(*(arr.shape for arr in (train_l, val_l, test_l, train_ebd, val_ebd, test_ebd)), sep="\n")

total_labels = np.concatenate([train_l, val_l, test_l])
print(total_labels.shape)
total_embeddings = np.concatenate([train_ebd, val_ebd, test_ebd])
print(total_embeddings.shape)

(9590,)
(2398,)
(11588,)
(9590, 1024)
(2398, 1024)
(11588, 1024)
(23576,)
(23576, 1024)


In [17]:
def _record_embedding_stats(statistics, key, embeddings, return_dist):
    """Compute the mean vector, its norm and (optionally) the averaged pairwise distance of
    `embeddings`, and store each one under `key` in the matching sub-dict of `statistics`."""
    mean_vector = compute_mean_vector(embeddings, axis=0)
    statistics['mean_vector'][key] = mean_vector
    statistics['mean_vector_norm'][key] = compute_vector_norm(mean_vector)
    if return_dist:
        statistics['pairwise_distance'][key] = compute_averaged_pairwise_distance(embeddings)


def GetGroupWiseStatEbd(embeddings, group_labels, return_dist = True):
    """
    Compute embedding statistics for the whole set and for every group.

    Args:
        embeddings: per-sample embeddings, indexable along axis 0 with an integer index array
            (presumably a numpy array of shape (N, D) -- confirm against callers).
        group_labels: per-sample group ids, aligned with `embeddings`. Lists (such as the
            "groups" entry returned by validate_adapter_with_return) are converted with
            ``np.asarray`` so the group mask is always computed element-wise.
        return_dist: if True, also compute the averaged pairwise distance of each subset.

    Returns:
        dict with keys 'mean_vector', 'mean_vector_norm' and 'pairwise_distance'. Each maps
        'full' (all samples) and every unique group id to the corresponding statistic.
        'pairwise_distance' stays empty when `return_dist` is False.
    """
    statistics = {
        'mean_vector': {},
        'mean_vector_norm': {},
        'pairwise_distance': {},
    }
    group_labels = np.asarray(group_labels)

    # Full dataset
    _record_embedding_stats(statistics, 'full', embeddings, return_dist)

    # Each group present in the labels (e.g. 0, 1, 2, 3)
    for group in np.unique(group_labels):
        group_indices = np.where(group_labels == group)[0]
        _record_embedding_stats(statistics, group, embeddings[group_indices], return_dist)

    return statistics

In [18]:
def validate_adapter_with_return(opt, val_loader, classifier, criterion, get_yp_func, train_group_ratio, target, print_label='Test'):
    """
    Evaluate an adapter-based classifier on `val_loader` and also return the adapted data.

    The classifier is put in eval mode and run under ``torch.no_grad()``. Each batch of input
    embeddings goes through ``classifier.adapter``. Logits are computed by dividing the product
    of the L2-normalized adapted features and the column-normalized ``classifier.text_features``
    by ``classifier.temperature``.

    Args:
        opt: options; reads ``tl_method`` (must contain "adapter"), ``watch_batch_results``
            and ``print_freq``.
        val_loader: loader yielding ``(embeddings, all_labels, img_filenames)``. ``all_labels``
            is a dict holding at least `target`, 'group' and 'spurious'. Its dataset must
            expose ``n_groups``.
        classifier: module exposing ``adapter``, ``text_features`` and ``temperature``.
        criterion: loss applied to ``(logits, labels)``.
        get_yp_func: maps a group index to a ``(y, p)`` pair used to build ``acc_{y}_{p}`` keys.
        train_group_ratio: per-group weights used for the weighted mean accuracy.
        target: key of ``all_labels`` used as the label (one of 'class', 'spurious', 'group').
        print_label: prefix for the printed results.

    Returns:
        ``((avg_loss, avg_acc, group_acc), (total_embeddings, total_meta_results))``.
        ``total_embeddings`` is a numpy array of the adapted (not normalized) features of every
        sample. ``total_meta_results`` holds per-sample lists under "targets", "spuriouss",
        "groups" and "predictions".
    """
    # Per-sample results collected across all batches.
    total_embeddings = []
    total_labels = []
    total_spuriouss = []
    total_groups = []
    # total_confidences = []
    total_predictions = []  # argmax predictions of the adapter classifier

    classifier.eval()

    batch_time = AverageMeter()
    losses = AverageMeter()
    acc = AverageMeter()
    acc_groups = {g_idx: AverageMeter() for g_idx in range(val_loader.dataset.n_groups)}

    with torch.no_grad():
        end = time.time()
        for idx, data in enumerate(val_loader):
            # all_labels.keys(): ['class', 'group', 'spurious', 'ebd_pred' (CLIP zero-shot)]
            embeddings, all_labels, img_filenames = data
            labels = all_labels[target]  # target: one of [class, spurious, group]

            # Group ids drive the group-wise accuracy (and group-aware approaches).
            groups = all_labels['group']
            places = all_labels["spurious"]

            embeddings = embeddings.float().cuda()
            labels = labels.cuda()
            bsz = labels.shape[0]

            # NOTE: input embeddings are deliberately left unnormalized; to normalize them:
            # embeddings = embeddings / embeddings.norm(dim=-1, keepdim=True)

            # forward
            assert "adapter" in opt.tl_method

            image_features = classifier.adapter(embeddings)

            image_features_normalized = image_features / image_features.norm(dim=-1, keepdim=True)  # Normalized (B, 1024)
            text_features_normalized = classifier.text_features / classifier.text_features.norm(dim=0, keepdim=True)  # Normalized (1024, 2)
            logits = image_features_normalized @ text_features_normalized / classifier.temperature  # (B, 1024) x (1024, 2) = (B, 2)

            loss = criterion(logits, labels)

            predicted = logits.argmax(axis=1)

            # save (the stored embeddings are the adapter outputs before normalization)
            total_labels.extend(labels.cpu().numpy())
            total_groups.extend(groups.numpy())
            total_spuriouss.extend(places.numpy())
            total_predictions.extend(predicted.cpu().numpy())
            total_embeddings.extend(image_features.cpu().numpy())

            # update metric
            losses.update(loss.item(), bsz)
            acc1 = accuracy(logits, labels, bsz)
            acc.update(acc1, bsz)

            # measure elapsed time
            batch_time.update(time.time() - end)
            end = time.time()

            # Update acc dict
            update_dict(acc_groups, labels, groups, logits)

            if opt.watch_batch_results and (idx + 1) % opt.print_freq == 0:
                # Built entirely with f-strings. The previous mix of an f-string and .format()
                # turned "{0}/{1}" into the literals 0 and 1, so progress always read "[0/1]".
                print(f'{print_label}: [{idx}/{len(val_loader)}]\t'
                      f'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                      f'Loss {losses.val:.4f} ({losses.avg:.4f})\t'
                      f'Acc@1 {acc.val:.3f} ({acc.avg:.3f})')

    total_embeddings = np.array(total_embeddings)  # (# of full data, feat_dim)

    # These entries don't need to be arrays, so they are returned as plain lists.
    total_meta_results = {"targets": total_labels, "spuriouss": total_spuriouss, "groups": total_groups,
                          "predictions": total_predictions}

    group_acc = get_results(acc_groups, get_yp_func)

    # NOTE Add Weighted mean acc.
    groups = range(val_loader.dataset.n_groups)  # 0, 1, 2, 3
    group_acc_indiv = [group_acc[f"acc_{get_yp_func(g)[0]}_{get_yp_func(g)[1]}"] for g in groups]
    weighted_mean_acc = (np.array(group_acc_indiv) * np.array(train_group_ratio)).sum()  # Weighted sum

    group_acc["weighted_mean_acc"] = weighted_mean_acc
    group_acc = {key: group_acc[key] for key in new_order_for_print}
    group_acc = {key: np.round(value, 4) for key, value in group_acc.items()}
    print(f"{print_label}:", str(group_acc))

    return (losses.avg, acc.avg, group_acc), (total_embeddings, total_meta_results)


In [11]:
# Start strictly below any accuracy so that the first evaluated epoch is always a candidate.
# With the previous start value of 0 and a strict ">", best_epoch could stay 0, and
# `*_group_accs[best_epoch-1]` then silently reported the LAST epoch as the best one.
best_acc = -1.0
best_epoch = 0
best_model = None
# opt = parse_option()

# Main CE code: this cell implements the contrastive-adapter (CA, then CE) schedule only.
if opt.tl_method == "contrastive_adapter":
    print("> set and load Contrastive data-handler")
    print('========================================================================')
    sliced_data_indices, sliced_data_correct = compute_slice_indices(trainset)
    contrastive_points = prepare_contrastive_points(trainset, sliced_data_indices, sliced_data_correct)
    slice_anchors, slice_negatives, positives_by_class, all_targets = contrastive_points

    adjust_num_pos_neg_(positives_by_class, slice_negatives, opt)
    contrastive_loss = SupervisedContrastiveLoss(opt)

    # Get contrastive batches for first epoch
    original_contrastive_batch_samples = construct_contrastive_data(slice_anchors, slice_negatives, positives_by_class, opt)  # #[(254, 1719), (94, 1)]
    contrastive_dataloader = load_contrastive_loader(trainset, original_contrastive_batch_samples, opt, True)

    negatives_by_class = GetNegativesByClass(trainset, positives_by_class)
    weights_resampled_ce = GetResampledWeightsCE(trainset, positives_by_class, negatives_by_class, opt)
    # num_samples = len(trainset) with replacement: the more the minority slices are over-sampled,
    # the more majority-group samples go unseen in a given epoch.
    ce_sampler = WeightedRandomSampler(weights=weights_resampled_ce, num_samples=len(trainset), replacement=True)
    resampled_train_loader = DataLoader(trainset, sampler=ce_sampler, batch_size=opt.batch_size, num_workers=8)
    skim_dataloader_by_group(resampled_train_loader)

    print('========================================================================')
else:
    # The loop below unconditionally needs contrastive_dataloader / contrastive_loss /
    # resampled_train_loader, which are only built above. Fail fast instead of crashing
    # mid-run or silently reusing stale loaders left over from an earlier execution.
    raise ValueError(f"This training cell only supports tl_method='contrastive_adapter', got {opt.tl_method!r}")

# build optimizer
print("Set Optimizer: SGD (default)")
print('========================================================================')
optimizer = set_optimizer(opt, classifier)

# training routine
train_losses = []
train_losses_cl = []
train_accs = []
train_group_accs = []

val_losses = []
val_accs = []
val_group_accs = []

test_losses = []  # NOTE: Don't peek !
test_accs = []  # NOTE: Don't peek !
test_group_accs = []  # NOTE: Don't peek !

# Validate train (initial adapter, before any update).
# NOTE(review): this overwrites the notebook-level `total_embeddings` / `total_meta_results`.
((_, _, _),
 (total_embeddings, total_meta_results)) = validate_adapter_with_return(opt, train_loader, classifier, ce_loss, get_yp_func, train_group_ratio, target=opt.train_target, print_label='Train(Initial Adapter)')

# NOTE(review): the test split is always evaluated against 'class' while train/val use
# opt.train_target; identical here (train_target == "class") -- confirm if that changes.

# entire training
for epoch in range(1, opt.epochs + 1):
    adjust_learning_rate(opt, optimizer, epoch)
    print(f'--- Epoch {epoch} ---')

    # train one epoch (contrastive stage)
    loss_cl = train_one_epoch_cl(opt, contrastive_dataloader, classifier, contrastive_loss,
                                 optimizer, epoch, print_label='Train(Contrastive Learning)')
    train_losses_cl.append(loss_cl)

    # Evaluate all splits after the contrastive (CA) stage
    ((_, _, train_ca_group_acc),
     (train_total_embeddings, train_total_meta_results)) = validate_adapter_with_return(opt, train_loader, classifier, ce_loss, get_yp_func, train_group_ratio, target=opt.train_target, print_label='Train(After CA)')
    ((_, _, val_ca_group_acc),
     (val_total_embeddings, val_total_meta_results)) = validate_adapter_with_return(opt, val_loader, classifier, ce_loss, get_yp_func, train_group_ratio, target=opt.train_target, print_label='Val(After CA)')
    ((_, _, test_ca_group_acc),
     (test_total_embeddings, test_total_meta_results)) = validate_adapter_with_return(opt, test_loader, classifier, ce_loss, get_yp_func, train_group_ratio, target='class', print_label='Test(After CA)')

    train_total_meta_results['training_method'] = "CA"; train_total_meta_results['group_accs'] = train_ca_group_acc
    val_total_meta_results['training_method'] = "CA"; val_total_meta_results['group_accs'] = val_ca_group_acc
    test_total_meta_results['training_method'] = "CA"; test_total_meta_results['group_accs'] = test_ca_group_acc

    append_like_queue(train_ebd_q, (train_total_embeddings, train_total_meta_results), opt.max_length_ebd_queue)
    append_like_queue(val_ebd_q, (val_total_embeddings, val_total_meta_results), opt.max_length_ebd_queue)
    append_like_queue(test_ebd_q, (test_total_embeddings, test_total_meta_results), opt.max_length_ebd_queue)

    # train one epoch (cross-entropy stage on the re-sampled loader)
    loss, acc, group_acc = train_one_epoch(opt, resampled_train_loader, classifier, ce_loss,
                                           optimizer, epoch, get_yp_func, target=opt.train_target, print_label='Train(Cross entropy)(for all training set)')
    train_losses.append(loss); train_accs.append(acc); train_group_accs.append(group_acc)

    # Evaluate all splits after the cross-entropy (CE) stage
    ((_, _, train_ce_group_acc),
     (train_total_embeddings, train_total_meta_results)) = validate_adapter_with_return(opt, train_loader, classifier, ce_loss, get_yp_func, train_group_ratio, target=opt.train_target, print_label='Train(After CE)')
    ((val_loss, val_acc, val_ce_group_acc),
     (val_total_embeddings, val_total_meta_results)) = validate_adapter_with_return(opt, val_loader, classifier, ce_loss, get_yp_func, train_group_ratio, target=opt.train_target, print_label='Val(After CE)')
    ((test_loss, test_acc, test_ce_group_acc),
     (test_total_embeddings, test_total_meta_results)) = validate_adapter_with_return(opt, test_loader, classifier, ce_loss, get_yp_func, train_group_ratio, target='class', print_label='Test(After CE)')

    train_total_meta_results['training_method'] = "CE"; train_total_meta_results['group_accs'] = train_ce_group_acc
    val_total_meta_results['training_method'] = "CE"; val_total_meta_results['group_accs'] = val_ce_group_acc
    test_total_meta_results['training_method'] = "CE"; test_total_meta_results['group_accs'] = test_ce_group_acc

    append_like_queue(train_ebd_q, (train_total_embeddings, train_total_meta_results), opt.max_length_ebd_queue)
    append_like_queue(val_ebd_q, (val_total_embeddings, val_total_meta_results), opt.max_length_ebd_queue)
    append_like_queue(test_ebd_q, (test_total_embeddings, test_total_meta_results), opt.max_length_ebd_queue)

    val_losses.append(val_loss); val_accs.append(val_acc); val_group_accs.append(val_ce_group_acc)
    test_losses.append(test_loss); test_accs.append(test_acc); test_group_accs.append(test_ce_group_acc)

    # update best epoch by worst_group accuracy (default)
    if val_ce_group_acc['worst_acc'] > best_acc:
        best_acc = val_ce_group_acc['worst_acc']
        best_epoch = epoch
        best_model = copy.deepcopy(classifier)

    if opt.re_shuffle_ca_loader:
        contrastive_dataloader = load_contrastive_loader(trainset, original_contrastive_batch_samples, opt, True)


print('========================================================================')
print("> end of training. \n")
print('best epoch : {}'.format(best_epoch))

if best_epoch == 0:
    # No epoch ran (opt.epochs < 1) or worst_acc never compared greater than best_acc
    # (e.g. NaN). Indexing with best_epoch-1 == -1 would report the wrong epoch or raise.
    print('No best epoch was selected; skipping best-epoch report.')
    best_train_group_acc = best_val_group_acc = best_test_group_acc = None
else:
    best_train_group_acc = train_group_accs[best_epoch - 1]
    best_val_group_acc = val_group_accs[best_epoch - 1]
    best_test_group_acc = test_group_accs[best_epoch - 1]

    print(f'best training accuracy on [{opt.train_target}]: {best_train_group_acc}')
    print(f'best validation accuracy on [{opt.train_target}]: {best_val_group_acc}')
    print(f'best test accuracy on [{opt.train_target}]: {best_test_group_acc}')

  sliced_data_incorrect = np.array(sliced_data_incorrect)


> set and load Contrastive data-handler
>> Slice 0, target: 0, counts: 3588
>> Slice 0, target: 1, counts: 254
Slice 0 % incorrect: 6.6111 %
ix class counts [1] [254]
nix class counts [0] [3588]
Slice 0 # negative (correct): 3588
Slice 0 % negative (correct): 93.3889 %
Unique negative targets: (array([0]), array([3588]))
nix class counts(for positive) [0] [3588]
Slice 0 # Positive: (for 'other' slice) 3588
>> Slice 1, target: 0, counts: 94
>> Slice 1, target: 1, counts: 859
Slice 1 % incorrect: 9.8636 %
ix class counts [0] [94]
nix class counts [1] [859]
Slice 1 # negative (correct): 859
Slice 1 % negative (correct): 90.1364 %
Unique negative targets: (array([1]), array([859]))
nix class counts(for positive) [1] [859]
Slice 1 # Positive: (for 'other' slice) 859
> Add Easy Negatives samples (0523)
>> % Negatives for Anoter Slice 1): 6.6111 %
859 -> 1113
> Add Easy Negatives samples (0523)
>> % Negatives for Anoter Slice 0): 9.8636 %
3588 -> 3682
given args: number of anchors: 1
given ar

Generating data from slice 0: 100%|██████████| 254/254 [00:00<00:00, 9795.54it/s]


batch_samples_per_slice.shape (254, 1973)


Generating data from slice 1: 100%|██████████| 94/94 [00:00<00:00, 9503.10it/s]

batch_samples_per_slice.shape (94, 1973)
No balancing contrastive samples (Focusing on class with more prediction errors)
Shuffle for concatenated-contrastive-batch
First 10 Anchor indices:  [  97    4 4594 2419 3506  177  152 1510 1262 2317]
Shape of Contrastive Batch Samples :  (348, 1973)





len(resampled_set.targets) 686604
> batchsize of contrastive data loader : 63136
> len of contrastive dataset : 686604
> Class 0
# of samples : 3682
# of positives :  3588
# of negatives :  94
> Class 1
# of samples : 1113
# of positives :  859
# of negatives :  254
Re-sampling for [Cross Entropy Loader]
> Class 0 Weight: 38.170
> Class 1 Weight: 3.382
> Final sampling weights set : [ 1.          3.38188976 38.17021277]
>> Corresponding Counts : [4447  254   94] / 4795
Distribution of some mini batch 
[70, 33, 25]
[75, 25, 3, 25]
[73, 32, 2, 21]
[73, 26, 4, 25]
[85, 18, 2, 23]
[85, 18, 2, 23]
[85, 20, 2, 21]
[80, 23, 1, 24]
[77, 29, 22]
[85, 21, 2, 20]
[72, 25, 1, 30]
Set Optimizer: SGD (default)
Train(Initial Adapter): {'weighted_mean_acc': 0.9449, 'worst_acc': 0.3929, 'acc_0_0': 0.9917, 'acc_0_1': 0.9076, 'acc_1_0': 0.3929, 'acc_1_1': 0.8259, 'mean_acc': 0.9449}
--- Epoch 1 ---
>>CA Optimzer.step
Loss in Train(Contrastive Learning): 0.252 Pos:0.0706, Neg: 0.0713
Train(After CA): {'we