In [None]:
# If you are using SageMaker, first download the dataset from S3:
!pip install awscli
!aws s3 cp s3://handata/ref_youtube_audio/ ref_youtube_audio/ --recursive

In [None]:
!pip install transformers
!pip install -U openai-whisper
!pip install librosa

In [1]:
from transformers import AutoFeatureExtractor, WhisperForAudioClassification
import torch
import torch.nn as nn
import whisper
from torch.utils.data import Dataset, DataLoader
import torch
import torch.nn as nn
import whisper
import pandas as pd
from categories import ytvos_category_dict
import numpy as np
from util import read_aws_json,read_aws_wav,read_local_json,read_local_wav
import logging
from torch import optim
from losses import get_loss_func
from utils.evaluate import Evaluator
from util import infoNCE_loss
import random
from tqdm.notebook import tqdm
from enum import Enum
from sklearn.metrics import f1_score,precision_recall_curve,precision_score,recall_score,accuracy_score,balanced_accuracy_score
import os

# Execution-environment switches (not read anywhere in this part of the notebook).
SageMaker = False
Local = True
# Data and checkpoint roots. Set CLUL_ROOT / CLUL_SAVEDIR to override them instead of
# editing the notebook. Both are used with plain string concatenation (ROOT + 'task_split_1/...'),
# so os.path.join(..., '') guarantees a trailing separator; the defaults are unchanged.
ROOT = os.path.join(os.environ.get('CLUL_ROOT', 'C:/Users/Administrator/Desktop/CLUL-main/data/'), '')
SAVEDIR = os.path.join(os.environ.get('CLUL_SAVEDIR', 'C:/Users/Administrator/Desktop/CLUL-main/run/'), '')

In [2]:
class Audio_Encoder(nn.Module):
    """Frozen Whisper encoder followed by a small trainable classification head.

    Forward pipeline:
        raw waveforms -> feature_extractor (called at 16 kHz) -> frozen encoder
        -> AvgPool over `pool_num` frames -> Linear(768 -> 256) projector
        -> flatten -> fc1 -> fc2 -> fc3 -> class logits.

    Args:
        feature_extractor: callable like HF ``WhisperFeatureExtractor``. It is called once per
            clip with ``sampling_rate=16000, return_tensors="pt"`` and must expose
            ``.input_features``.
        model: object whose ``.encoder`` maps the concatenated features to hidden states. The
            sizes below assume ``(batch, 1500, 768)`` hidden states (whisper-small) — TODO
            confirm if another checkpoint is used.
        num_class: number of output logits.
        dropout_prob: dropout applied to the flattened feature and between the fc layers.
        pool_num: temporal pooling window (and stride).
        bias: unused; kept for backward compatibility.
    """

    def __init__(self, feature_extractor, model, num_class=66,dropout_prob=0.2,pool_num = 100,bias = True):
        super().__init__()
        self.num_class = num_class
        # Kept so existing readers of `.device` still work; forward() uses the device the
        # parameters actually live on, so the module keeps working after .to()/.cpu().
        self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
        self.feature_extractor = feature_extractor
        self.encoder = model.encoder
        # The encoder is frozen; only the head below receives gradients.
        for param in self.encoder.parameters():
            param.requires_grad = False
        self.projector = nn.Linear(in_features=768, out_features=256, bias=True)
        # NOTE(review): `classifier`, `batchnorm` and `dropout2` are not used in forward(); they
        # are kept so parameter names/order (checkpoints, optimizer) stay unchanged.
        self.classifier = nn.Linear(256, num_class)

        # On a 3-D (batch, frames, dim) tensor AvgPool2d averages `pool_num` consecutive frames.
        self.avg_pool = nn.AvgPool2d(kernel_size=(pool_num,1), stride=(pool_num,1))
        self.batchnorm = nn.BatchNorm1d(2048, affine=False)
        self.dropout = nn.Dropout(p=dropout_prob)
        self.dropout2 = nn.Dropout(0.5)

        # fc1's input size assumes 1500 encoder frames pooled down to 1500 // pool_num.
        self.fc1 = nn.Linear(1500//pool_num * 256, 2048)
        self.fc2 = nn.Linear(2048, 256)
        self.fc3 = nn.Linear(256, num_class)

    def forward(self, audios):
        """Classify a batch of waveforms.

        Args:
            audios: iterable of 1-D waveform tensors (lengths may differ).

        Returns:
            dict with ``clipwise_output`` (batch, num_class) logits, ``feature`` (the flattened
            pooled + projected encoder output fed to the MLP) and ``embedding`` (the raw
            encoder hidden states).
        """
        # Follow the head's real device instead of the one guessed in __init__.
        device = self.fc1.weight.device
        input_features = torch.cat(
            [
                self.feature_extractor(audio.cpu(), sampling_rate=16000, return_tensors="pt").input_features
                for audio in audios
            ],
            dim=0,
        ).to(device)
        hidden_states = self.encoder(input_features)

        x = self.avg_pool(hidden_states)
        x = self.projector(x)
        feature = x.reshape(x.shape[0], -1)

        # NOTE(review): there is no non-linearity between fc1, fc2 and fc3, so the MLP is
        # equivalent to a single linear map. Left unchanged to keep trained checkpoints valid.
        x = self.dropout(feature)
        x = self.fc1(x)
        x = self.dropout(x)
        x = self.fc2(x)
        x = self.dropout(x)
        x = self.fc3(x)

        return {
            'clipwise_output': x,
            'feature': feature,
            'embedding': hidden_states}

class ytvos_Dataset(Dataset):
    """Ref-YouTube-VOS audio clips with one-hot category targets.

    Each row of ``data_frame`` must provide ``video``, ``audio``, ``exp`` and ``category``.
    The wav is read from ``ROOT + 'ref_youtube_audio/audio/<video>/<audio>.wav'``.
    """

    def __init__(self, data_frame: pd.DataFrame, sr=44100, num_class=66):
        self.data_frame = data_frame
        self.sr = sr  # NOTE(review): stored but not used when reading audio
        self.num_class = num_class
        # NOTE(review): legacy path, not used; audio is resolved against the global ROOT.
        self.data_root = '/home/user/SED_Adaptation_Classifier-main/data/ref_youtube_audio/audio'

    def __len__(self):
        return len(self.data_frame)

    def __getitem__(self, index):
        """Return ``{'audio_name', 'waveform', 'target', 'tag'}`` for row ``index``."""
        if torch.is_tensor(index):
            index = index.tolist()
        # Fetch the row once instead of re-indexing the frame for every column.
        row = self.data_frame.iloc[index]
        audio_name = row["video"]
        audio_id = row["audio"]
        audio_path = 'ref_youtube_audio/audio' + '/' + audio_name + '/' + audio_id + '.wav'
        name = audio_name + row["exp"]

        waveform = read_local_wav(ROOT + audio_path)

        tag = row["category"]
        # Build the one-hot row directly. np.eye(num_class)[k] allocates a full
        # num_class x num_class matrix per sample; values and dtype (float64) are identical.
        target = np.zeros(self.num_class)
        target[ytvos_category_dict[tag]] = 1.0
        return {'audio_name': name, 'waveform': waveform, 'target': target, 'tag': tag}

def get_datalist(cur_iter):
    """Return ``(train_metas, test_metas)`` for task ``cur_iter`` of split 1.

    ``task{cur_iter}.json`` maps each category to ``train``/``test`` id lists. The ids are
    looked up in the ``metas`` collection of ``metas.json``, and the results are concatenated
    in category order.
    """
    split_dir = ROOT + 'task_split_1/'
    metas = read_local_json(split_dir + 'metas.json')['metas']
    per_category = read_local_json(split_dir + 'task{}.json'.format(cur_iter))[str(cur_iter)]

    train_metas = [metas[meta_id] for ids in per_category.values() for meta_id in ids['train']]
    test_metas = [metas[meta_id] for ids in per_category.values() for meta_id in ids['test']]
    return train_metas, test_metas
    
def default_collate_fn(batch):
    """Collate dataset items into a batch dict.

    Waveforms keep their individual lengths and are returned as a list of 1-D tensors (no
    padding). Targets are stacked into a float32 tensor.

    Args:
        batch: list of dicts with ``audio_name``, ``waveform`` (numpy array) and ``target``.

    Returns:
        ``{'audio_name': list, 'waveform': list[Tensor], 'target': FloatTensor}``.
    """
    audio_name = [item['audio_name'] for item in batch]
    waveform = [torch.from_numpy(item['waveform']) for item in batch]
    # Stack in numpy first. torch.FloatTensor(list_of_ndarrays) converts element by element,
    # which is very slow (PyTorch emits a warning about it). Values and dtype are unchanged.
    target = torch.from_numpy(np.asarray([item['target'] for item in batch], dtype=np.float32))
    return {'audio_name': audio_name, 'waveform': waveform, 'target': target}

def get_dataloader(data_frame, dataset, split, batch_size, num_class, num_workers=8, shuffle=None):
    """Build a DataLoader over ``ytvos_Dataset`` for one split.

    Args:
        data_frame: records for the split.
        dataset: must be ``"ref_youtube_audio"``.
        split: ``'train'`` or any other split name.
        batch_size, num_workers: passed to ``DataLoader``.
        num_class: NOT forwarded. The dataset always one-hot encodes with its own default
            (66), and some callers in this notebook pass the number of classes seen so far here.
        shuffle: defaults to ``split == 'train'``. Non-training loaders keep the row order of
            ``data_frame``, so per-sample outputs can be matched back to the input records (the
            mutual-information sampling zips outputs with its candidate list).
    """
    assert dataset == "ref_youtube_audio"
    if shuffle is None:
        shuffle = split == 'train'
    return DataLoader(dataset=ytvos_Dataset(data_frame=data_frame), batch_size=batch_size,
                      shuffle=shuffle, drop_last=False,
                      num_workers=num_workers, collate_fn=default_collate_fn)

def get_train_test_dataloader(batch_size, n_worker, train_list, test_list):
    """Return ``(train_loader, test_loader)`` built from lists of record dicts (66 classes)."""
    return tuple(
        get_dataloader(pd.DataFrame(records), 'ref_youtube_audio', split=split_name,
                       batch_size=batch_size, num_class=66, num_workers=n_worker)
        for split_name, records in (('train', train_list), ('test', test_list))
    )


In [3]:
class CLUL:
    def __init__(self, batch_size=16, lr=1e-3, memory_size=500,
                 forget_size=100, epoch=3, loss='focal_loss',
                 total_class_num=65, mode='CLUL _no_foget_bank',
                 patience=10, n_worker=0,
                 **kwargs):
        """Set up the Whisper-small learner and its continual-learning / unlearning state.

        Args:
            batch_size, lr, epoch, patience: training-loop hyperparameters.
            memory_size: target size of the replay memory bank.
            forget_size: budget for the forget bank.
            loss: loss name passed to ``get_loss_func``.
            total_class_num: number of real classes; the same value is the forget-label index.
            mode: sub-directory name used for checkpoints.
            n_worker: DataLoader worker count.
            **kwargs: accepted and ignored.
        """
        extractor = AutoFeatureExtractor.from_pretrained("openai/whisper-small")
        backbone = whisper.load_model("small")

        self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
        self.model = Audio_Encoder(extractor, backbone).to(self.device)

        # Training hyperparameters and logging.
        self.batch_size = batch_size
        self.lr = lr
        self.epoch = epoch
        self.logger = logging.getLogger()

        # Replay / forget buffers.
        self.forget_list = []
        self.memory_list = []
        self.memory_size = memory_size
        self.forget_size = forget_size

        # Optimisation and evaluation.
        self.optimizer = optim.Adam(self.model.parameters(), lr=self.lr, betas=(0.9, 0.999))
        self.criterion = get_loss_func(loss)
        self.num_pretrain_class = 0
        self.evaluator = Evaluator(self.model, self.num_pretrain_class, self.device)

        self.mode = mode
        self.patience = patience
        self.counter = 0
        self.total_class_num = total_class_num
        self.forget_label = total_class_num
        self.n_worker = n_worker

        # Class indices introduced by each continual-learning task (13 per task).
        task_classes = [
            [15, 17, 60, 50, 32, 24, 63, 36, 31, 40, 52, 4, 25],
            [48, 54, 35, 62, 13, 42, 37, 49, 51, 45, 44, 14, 5],
            [46, 18, 57, 28, 11, 30, 61, 27, 22, 2, 29, 0, 19],
            [3, 59, 10, 12, 8, 1, 26, 23, 34, 58, 64, 56, 41],
            [47, 20, 53, 39, 9, 21, 16, 38, 33, 43, 6, 7, 55],
        ]
        self.cltask = {f"task{i}": classes for i, classes in enumerate(task_classes)}
        # Classes to unlearn at each step: none at step 0, then the first three classes of
        # the previous task.
        unlearn_classes = [[], [15, 17, 60], [48, 54, 35], [46, 18, 57], [3, 59, 10], [47, 20, 53]]
        self.ultask = {f"ul_task{i}": classes for i, classes in enumerate(unlearn_classes)}
           
    def evaluate(self,model_path,cur_iter):
        self.change_model(model_path)
        train_list,test_list = self.get_train_test_datalist(cur_iter)
        _, test_loader = get_train_test_dataloader(self.batch_size, self.n_worker, train_list, test_list)
        y_true,y_pred = self.evaluator.evaluate(test_loader)
        
        cl_class_label,ul_class_label = self.get_cl_ul_class_label(cur_iter)
        # statistics = self.calculate_metrics(y_true,y_pred,cl_class_label,ul_class_label)
        # print(y_true,y_pred,cl_class_label,ul_class_label)
        # return statistics
        return y_true,y_pred,cl_class_label,ul_class_label

    def get_train_test_datalist(self,cur_iter):
        train_list,test_list = get_datalist(cur_iter)
        for i in range(cur_iter):
            _,test_list_ = get_datalist(i)
            test_list += test_list_

        return train_list,test_list
        
    def train(self,cur_iter):
        streamed_list,test_list = get_datalist(cur_iter)
        train_list = streamed_list + self.memory_list
        random.shuffle(train_list)
        train_loader, test_loader = get_train_test_dataloader(self.batch_size, self.n_worker, train_list, test_list)

        self.logger.info(f"Streamed samples: {len(streamed_list)}")
        self.logger.info(f"In-memory samples: {len(self.memory_list)}")
        self.logger.info(f"Train samples: {len(train_list)}")
        self.logger.info(f"Test samples: {len(test_list)}")
        # logger.info(f"Model: {self.model}")
        self.logger.info(f"Optimizer: {self.optimizer}")
        acc_list = []
        best = {'acc': 0, 'epoch': 0,'f1_score':0}

        for epoch in range(self.epoch):
            mean_loss = 0
            for idx,batch_data_dict in enumerate(tqdm(train_loader)):
                batch_data_dict['waveform'] = batch_data_dict['waveform']
                batch_data_dict['target'] = batch_data_dict['target'].to(self.device)

                # Forward
                self.model.train()

                batch_output_dict = self.model(batch_data_dict['waveform'])
                """{'clipwise_output': (batch_size, classes_num), ...}"""
                batch_target_dict = {'target': batch_data_dict['target']}
                """{'target': (batch_size, classes_num)}"""
                # Loss
                
                loss = self.criterion(batch_output_dict, batch_target_dict)
                self.logger.info(f'Batch Training Initial Loss: {loss}')
                if idx % 10 == 0:
                    print(f'Epoch:{epoch},Batch {idx} Loss: {loss}')
                # Backwards
                loss.backward()
                self.optimizer.step()
                self.optimizer.zero_grad()

                loss = loss.item()

                mean_loss += loss
            epoch_loss = mean_loss / len(train_loader)
            self.logger.info(f'Epoch {epoch} | Training Loss: {epoch_loss}')
            print(f'Epoch {epoch} | Training Loss: {epoch_loss}')
            # Evaluate
            y_true,y_pred = self.evaluator.evaluate(test_loader)
            weighted_accuracy = balanced_accuracy_score(y_true,y_pred)
            accuracy = accuracy_score(y_true,y_pred)
            # ave_f1_score = np.mean(test_statistics['f1_score'])
            # ave_acc = np.mean(test_statistics['accuracy'])
            # acc_list.append(ave_acc)
            # self.logger.info(f"Epoch {epoch} | Evaluation Accuracy: {ave_acc}|Evaluation f1_score: {ave_f1_score}")
            # self.logger.info(f'Current Accuracy: {ave_acc} in epoch {epoch}.|Current f1_score: {ave_f1_score} in epoch {epoch}.')
            # print(f"Task {cur_iter} | Epoch {epoch} | Evaluation Accuracy: {ave_acc}|Evaluation f1_score: {ave_f1_score}|Evaluation precision {test_statistics['precision']}")
            

            if weighted_accuracy > best['weighted_accuracy']:
                best['weighted_accuracy'] = weighted_accuracy
                best['accuracy'] = accuracy
                best['epoch'] = epoch
                self.logger.info(f'Best Accuracy: {accuracy} in epoch {epoch}.|Best weighted_accuracy: {weighted_accuracy} in epoch {epoch}.')
                selected_state_dict = {}
                for name, param in self.model.named_parameters():
                    if 'projector' in name or 'classifier' in name or 'fc' in name and ('encoder' not in name):
                        selected_state_dict[name] = param
                torch.save(selected_state_dict,SAVEDIR + '{}/task{}best_epoch{}.pt'.format(self.mode,cur_iter,epoch))
                self.counter = 0
            else:
                self.counter += 1
                self.logger.info(f'EarlyStopping counter: {self.counter} out of {self.patience}.')
                if self.counter >= self.patience:
                    break
        print(f"Task {cur_iter} | Best Epoch {best['epoch']} | Best Accuracy: {best['accuracy']}|Best weighted_accuracy: {best['weighted_accuracy']}")
        return 
    
    def change_model(self, path):
        checkpoint_dict = torch.load(path)
        for name, param in self.model.named_parameters():
            if name in checkpoint_dict:
                param.data.copy_(checkpoint_dict[name])
                
    def equal_class_sampling(self, samples, num_class):
        class_list = [self.cltask["task0"], self.cltask["task1"],self.cltask["task2"],self.cltask["task3"],self.cltask["task4"]]
        cur_class_list = []
        for i in range(num_class//13):
            cur_class_list += class_list[i]
        mem_per_cls = self.memory_size // num_class
        sample_df = pd.DataFrame(samples)

        # Warning: assuming the classes were ordered following task number.
        ret = []
        for y in cur_class_list:
            cls_df = sample_df[(sample_df["category"].map(ytvos_category_dict)) == y]
            ret += cls_df.sample(n=min(mem_per_cls, len(cls_df))).to_dict(
                orient="records"
            )

        num_rest_slots = self.memory_size - len(ret)
        if num_rest_slots > 0:
            self.logger.warning("Fill the unused slots by breaking the equilibrium.")
            ret += (
                sample_df[~sample_df.exp.isin(pd.DataFrame(ret).exp)]
                .sample(n=num_rest_slots)
                .to_dict(orient="records")
            )

        num_dups = pd.DataFrame(ret).exp.duplicated().sum()
        if num_dups > 0:
            self.logger.warning(f"Duplicated samples in memory: {num_dups}")

        return ret

    def get_data(self, infer_loader, augment):
        Z, Z_, predict_list = [], [], []
        self.model.eval()
        with torch.no_grad():
            for id, data in enumerate(infer_loader):
                wavs = data['waveform']
                aug_wavs = []
                for wav in wavs:
                    aug_wav = augment(wav.unsqueeze(0).unsqueeze(0), sample_rate=1600)
                    aug_wavs.append(torch.as_tensor(aug_wav.squeeze(0).squeeze(0), dtype=torch.float32))

                output_dict = self.model(data['waveform'])
                aug_output_dict = self.model(aug_wavs)

                for z, z_ in zip(output_dict['feature'], aug_output_dict['feature']):
                    Z.append(z)
                    Z_.append(z_)

                clipwise_output = output_dict['clipwise_output']
                pres = np.argmax(clipwise_output.detach().cpu(), axis=1)
                target = np.argmax(data['target'].cpu(), axis=1)

                for pre in pres: predict_list.append(pre.item())

            class_label_dic = self.save_indexes(predict_list)
        return Z, Z_, class_label_dic, predict_list
    
    def save_indexes(self,arr):
        index_dict = {}
        for idx, num in enumerate(arr):
            if num in index_dict:
                index_dict[num].append(idx)
            else:
                  index_dict[num] = [idx]
        return index_dict

    def forget_label_set(self,y_true,ul_class_label):
        index_row = torch.argmax(y_true,dim=1)
        for r ,c in enumerate(index_row):
            if int(c) in ul_class_label:
                with torch.no_grad():
                    y_true[r][c] = torch.tensor(0.0, dtype=torch.float32)
                    y_true[r][-1] = torch.tensor(1.0, dtype=torch.float32)
        return y_true
    
    def class_infoNCE(self, Z, Z_, class_label_dic, predict_list, temperature):
        ## You can change the method to calculate NCEs
        NCEs = []
        # print('This is cclass_label_dic',class_label_dic)
        for id in range(len(predict_list)):
            label = predict_list[id]
            same_label_list = class_label_dic[label]
            class_z = [Z[i] for i in same_label_list if i != id]
            class_z_ = [Z_[i] for i in same_label_list]

            positive_pair = class_z + class_z_

            positive_similarities = F.cosine_similarity(Z[id].unsqueeze(0), torch.stack(positive_pair)) / 2 + 0.5
            # print('This is postitive pair info',Z[id].unsqueeze(0).shape,torch.stack(positive_pair).shape,positive_similarities.shape)
            positive_value = torch.exp(positive_similarities / temperature).sum() / len(positive_pair)
            # print(positive_similarities,positive_value)
            neg_labels = [i for i in list(class_label_dic.keys()) if i != label]

            negative_values = 0
            for neg_label in neg_labels:
                neg_label_list = class_label_dic[neg_label]
                neg_z = [Z[i] for i in neg_label_list]
                neg_z_ = [Z_[i] for i in neg_label_list]
                negative_pair = neg_z + neg_z_
                negative_similarities = F.cosine_similarity(Z[id].unsqueeze(0), torch.stack(negative_pair)) / 2 + 0.5
                # print('This is negative pair info',Z[id].unsqueeze(0).shape,torch.stack(negative_pair).shape,negative_similarities.shape,len(negative_pair))
                negative_value = torch.exp(negative_similarities / temperature).sum() / len(negative_pair)
                # print(negative_similarities,negative_value)
                negative_values += negative_value

            NCE = -torch.log(positive_value / (positive_value + negative_values))
            # print('positive_value',positive_value,'negative values', negative_values,'this is single nce',NCE)
            NCEs.append(NCE)
        print(torch.stack(NCEs).shape)
        return torch.stack(NCEs)
    
    def single_mutual_info_sampling(self, train_list, cl_class_label, ul_class_label):
        from audiomentations import Compose, Gain, AddGaussianNoise, PitchShift,TimeStretch,Shift
        from collections import Counter
        
        val_class_lable = set(cl_class_label) - set(ul_class_label)
        
        
        # Unlearning Part:class deleted will not be added into the memory bank

        train_df = pd.DataFrame(train_list)

         = Counter(infer_df['category'])
        print('Before Unpdate Statistics')
        for name, number in class_count.items():
            print(name, number)
        # mem_per_cls = self.memory_size // num_class  # kc: the number of the samples of each class

        batch_size = 8
        temperature = 0.05
        ret = []
        infer_loader = get_dataloader(infer_df, 'ref_youtube_audio', split='test', batch_size=batch_size, num_class=num_class,
                                      num_workers=8)
        augment = Compose([
            # Gain(min_gain_in_db=-12.0, max_gain_in_db=12.0),
            # AddGaussianNoise(min_amplitude=0.001, max_amplitude=0.001),
            PitchShift(min_semitones=-0.5, max_semitones=0.5, p=0.5),
            # AddGaussianNoise(min_amplitude=0.001, max_amplitude=0.015),
            # TimeShift(min_fraction=-0.5, max_fraction=0.5, p=0.5),
            # Shift(min_shift=-0.5, max_shift=0.5, p=0.5),
            # TimeStretch(min_rate=0.9, max_rate=1.1, p=0.5),
        ])

        Z, Z_, class_label_dic, predict_list = self.get_data(infer_loader, augment)
        assert (len(Z) == len(Z_) == len(predict_list))

        cur_NCEs = self.class_infoNCE(Z, Z_, class_label_dic, predict_list, temperature)

        path = SAVEDIR + '{}/task{}best_epoch2.pt'.format(self.mode,cur-1)
        self.change_model(path)

        pre_Z, pre_Z_, pre_class_label_dic, pre_predict_list = self.get_data(infer_loader, augment)
        assert (len(Z) == len(Z_) == len(predict_list))
        pre_NCEs = self.class_infoNCE(pre_Z, pre_Z_, pre_class_label_dic, pre_predict_list, temperature)

        path = SAVEDIR + '{}/task{}best_epoch2.pt'.format(self.mode,cur)
        self.change_model(path)

        # print(len(Z),len(Z_),len(predict_list),len(candidates))

        NCEs = pre_NCEs - cur_NCEs
        for candidate,NCE in zip(candidates,NCEs):candidate['NCE'] = NCE

        sample_df = pd.DataFrame(candidates)
        mem_per_cls = self.memory_size // cur_class_list  # kc: the number of the samples of each class


        for i in cur_class_list:
            cls_df = sample_df[(sample_df["category"].map(ytvos_category_dict)) == i]
            if len(cls_df) <= mem_per_cls:
                ret += cls_df.to_dict(orient="records")
            else:
                jump_idx = len(cls_df) // mem_per_cls
                uncertain_samples = cls_df.sort_values(by="NCE")[::jump_idx]
                ret += uncertain_samples[:mem_per_cls].to_dict(orient="records")

        num_rest_slots = self.memory_size - len(ret)
        if num_rest_slots > 0:
            self.logger.warning("Fill the unused slots by breaking the equilibrium.")
            ret += (
                sample_df[~sample_df.exp.isin(pd.DataFrame(ret).exp)]
                .sample(n=num_rest_slots)
                .to_dict(orient="records")
            )

        num_dups = pd.DataFrame(ret).exp.duplicated().sum()
        if num_dups > 0:
            self.logger.warning(f"Duplicated samples in memory: {num_dups}")


        class_count = Counter(pd.DataFrame(ret)['category'])
        print('After Unpdate Statistics')
        for name, number in class_count.items():
            print(name, number)

        return ret
    
    def double_mutual_info_sampling(self, candidates, cur, num_class):
        from audiomentations import Compose, Gain, AddGaussianNoise, PitchShift,TimeStretch,Shift
        from collections import Counter
        
        ulclass_list =   [None,self.ultask["task1"],self.ultask["task2"],self.ultask["task3"],self.ultask["task4"]]
        class_list = [self.cltask["task0"], self.cltask["task1"],self.cltask["task2"],self.cltask["task3"],self.cltask["task4"]]
        cl_class_list = []
        ul_class_list = []
        for i in range(num_class // 13):
            cur_class_list |= set(class_list[i])
            cur_class_list -= set(ulclass_list[i])
        cur_class_list.add(self.total_class_num-1)
        # Unlearning Part:class deleted will not be added into the memory bank

        infer_df = pd.DataFrame(candidates)

        class_count = Counter(infer_df['category'])
        print('Before Unpdate Statistics')
        for name, number in class_count.items():
            print(name, number)
        # mem_per_cls = self.memory_size // num_class  # kc: the number of the samples of each class

        batch_size = 8
        temperature = 0.05
        ret = []
        infer_loader = get_dataloader(infer_df, 'ref_youtube_audio', split='test', batch_size=batch_size, num_class=num_class,
                                      num_workers=8)
        augment = Compose([
            # Gain(min_gain_in_db=-12.0, max_gain_in_db=12.0),
            # AddGaussianNoise(min_amplitude=0.001, max_amplitude=0.001),
            PitchShift(min_semitones=-0.5, max_semitones=0.5, p=0.5),
            # AddGaussianNoise(min_amplitude=0.001, max_amplitude=0.015),
            # TimeShift(min_fraction=-0.5, max_fraction=0.5, p=0.5),
            # Shift(min_shift=-0.5, max_shift=0.5, p=0.5),
            # TimeStretch(min_rate=0.9, max_rate=1.1, p=0.5),
        ])

        Z, Z_, class_label_dic, predict_list = self.get_data(infer_loader, augment)
        assert (len(Z) == len(Z_) == len(predict_list))

        cur_NCEs = self.class_infoNCE(Z, Z_, class_label_dic, predict_list, temperature)

        path = '/home/user/SED_Adaptation_Classifier-main/workspace/ref_youtube/MIO/iter{}epoch.pt'.format(cur - 1)
        self.change_model(path)

        pre_Z, pre_Z_, pre_class_label_dic, pre_predict_list = self.get_data(infer_loader, augment)
        assert (len(Z) == len(Z_) == len(predict_list))

        pre_NCEs = self.class_infoNCE(pre_Z, pre_Z_, pre_class_label_dic, pre_predict_list, temperature)

        path = '/home/user/SED_Adaptation_Classifier-main/workspace/ref_youtube/MIO/iter{}epoch.pt'.format(cur)
        self.change_model(path)

        # print(len(Z),len(Z_),len(predict_list),len(candidates))

        NCEs = pre_NCEs - cur_NCEs
        for candidate,NCE in zip(candidates,NCEs):candidate['NCE'] = NCE

        sample_df = pd.DataFrame(candidates)
         # kc: the number of the samples of each class in memory bank
        mem_per_cls = self.memory_size // len(cl_class_list)
        
        for_per_cls = self.forget_size// len(ul_class_list)
        


        for i in cur_class_list:
            cls_df = sample_df[(sample_df["category"].map(ytvos_category_dict)) == i]
            if len(cls_df) <= mem_per_cls:
                ret += cls_df.to_dict(orient="records")
            else:
                jump_idx = len(cls_df) // mem_per_cls
                uncertain_samples = cls_df.sort_values(by="NCE")[::jump_idx]
                ret += uncertain_samples[:mem_per_cls].to_dict(orient="records")

        num_rest_slots = self.memory_size - len(ret)
        if num_rest_slots > 0:
            logger.warning("Fill the unused slots by breaking the equilibrium.")
            ret += (
                sample_df[~sample_df.exp.isin(pd.DataFrame(ret).exp)]
                .sample(n=num_rest_slots)
                .to_dict(orient="records")
            )

        num_dups = pd.DataFrame(ret).exp.duplicated().sum()
        if num_dups > 0:
            logger.warning(f"Duplicated samples in memory: {num_dups}")


        # top_indices = np.argpartition(NCEs.cpu().numpy(), -2000)[-2000:]
        #
        # for index in top_indices:
        #     ret.append(candidates[index])

        class_count = Counter(pd.DataFrame(ret)['category'])
        print('After Unpdate Statistics')
        for name, number in class_count.items():
            print(name, number)

        return ret
    
    def train_with_datalist(self,train_list,test_list):
        
        train_loader, test_loader = get_train_test_dataloader(self.batch_size, self.n_worker, train_list, test_list)
        self.logger.info(f"In-memory samples: {len(self.memory_list)}")
        self.logger.info(f"Train samples: {len(train_list)}")
        self.logger.info(f"Test samples: {len(test_list)}")
        # logger.info(f"Model: {self.model}")
        self.logger.info(f"Optimizer: {self.optimizer}")
        acc_list = []
        best = {'acc': 0, 'epoch': 0,'f1_score':0}

        for epoch in range(self.epoch):
            mean_loss = 0
            for idx,batch_data_dict in enumerate(tqdm(train_loader)):
                batch_data_dict['waveform'] = batch_data_dict['waveform']
                batch_data_dict['target'] = batch_data_dict['target'].to(self.device)

                # Forward
                self.model.train()

                batch_output_dict = self.model(batch_data_dict['waveform'])
                """{'clipwise_output': (batch_size, classes_num), ...}"""
                batch_target_dict = {'target': batch_data_dict['target']}
                """{'target': (batch_size, classes_num)}"""
                # Loss
                
                loss = self.criterion(batch_output_dict, batch_target_dict)
                self.logger.info(f'Batch Training Initial Loss: {loss}')
                if idx % 10 == 0:
                    print(f'Epoch:{epoch},Batch {idx} Loss: {loss}')
                # Backwards
                loss.backward()
                self.optimizer.step()
                self.optimizer.zero_grad()

                loss = loss.item()

                mean_loss += loss
            epoch_loss = mean_loss / len(train_loader)
            self.logger.info(f'Epoch {epoch} | Training Loss: {epoch_loss}')
            print(f'Epoch {epoch} | Training Loss: {epoch_loss}')
            # Evaluate
            test_statistics = self.evaluator.evaluate(test_loader)
            ave_f1_score = np.mean(test_statistics['f1_score'])
            ave_acc = np.mean(test_statistics['accuracy'])
            acc_list.append(ave_acc)
            self.logger.info(f"Epoch {epoch} | Evaluation Accuracy: {ave_acc}|Evaluation f1_score: {ave_f1_score}")
            self.logger.info(f'Current Accuracy: {ave_acc} in epoch {epoch}.|Current f1_score: {ave_f1_score} in epoch {epoch}.')
            print(f"Task {cur_iter} | Epoch {epoch} | Evaluation Accuracy: {ave_acc}|Evaluation f1_score: {ave_f1_score}|Evaluation precision {test_statistics['precision']}")
            

            if ave_f1_score > best['f1_score']:
                best['acc'] = ave_acc
                best['f1_score'] = ave_f1_score
                best['epoch'] = epoch
                self.logger.info(f'Best Accuracy: {ave_acc} in epoch {epoch}.|Best f1_score: {ave_f1_score} in epoch {epoch}.')
                selected_state_dict = {}
                for name, param in self.model.named_parameters():
                    if 'projector' in name or 'classifier' in name or 'fc' in name and ('encoder' not in name):
                        selected_state_dict[name] = param
                torch.save(selected_state_dict,SAVEDIR + '{}/task{}best_epoch{}.pt'.format(self.mode,cur_iter,epoch))
                self.counter = 0
            else:
                self.counter += 1
                self.logger.info(f'EarlyStopping counter: {self.counter} out of {self.patience}.')
                if self.counter >= self.patience:
                    break
        print(f"Task {cur_iter} | Best Epoch {best['epoch']} | Best Evaluation Accuracy: {best['acc']}|Evaluation f1_score: {best['f1_score']}")
        return 
    
    def train_with_forget_without_forget_bank(self, cur_iter):
        """Train on task ``cur_iter`` plus the replay memory, evaluating after every epoch.

        Training data is the current task's train split extended with
        ``self.memory_list``. Each batch's targets are passed through
        ``self.forget_label_set`` together with the accumulated unlearn labels
        (presumably rewriting the targets of classes to be forgotten — confirm in
        ``forget_label_set``) before the loss is computed.

        After every epoch the model is evaluated on the test splits of tasks
        ``0..cur_iter``. Whenever the cl balanced accuracy improves, parameters whose
        names contain 'projector' or 'classifier', or 'fc' outside the encoder, are
        saved to ``SAVEDIR + '<mode>/task<cur_iter>best_epoch<epoch>.pt'``. Training
        stops early after ``self.patience`` consecutive epochs without improvement.

        Args:
            cur_iter: index of the current task (0-based).

        Returns:
            (train_list, test_list, cl_class_label, ul_class_label): the records used
            for training (current task + memory bank), the test records of tasks
            0..cur_iter, and the accumulated learn / unlearn class labels.
        """
        # The early-stopping counter is per task; without this reset a stale count
        # left over from the previous task could stop this task prematurely.
        self.counter = 0

        # Evaluate on the test splits of every task seen so far.
        test_list = []
        for i in range(cur_iter + 1):
            _, task_test_list = get_datalist(i)
            test_list += task_test_list

        # Train on the current task plus the replayed memory samples.
        train_list, _ = get_datalist(cur_iter)
        train_list += self.memory_list

        train_loader, test_loader = get_train_test_dataloader(self.batch_size, self.n_worker, train_list, test_list)
        cl_class_label, ul_class_label = self.get_cl_ul_class_label(cur_iter)

        best = {'cl_weighted_accuracy': 0, 'cl_accuracy': 0, 'ul_weighted_accuracy': 0, 'ul_accuracy': 0, 'epoch': 0}
        print('train loader length', len(train_loader), 'test loader length', len(test_loader), 'cl class label', cl_class_label, 'ul class label', ul_class_label)
        for epoch in range(self.epoch):
            # Set training mode once per epoch; the evaluator presumably switches
            # the model to eval mode at the end of the previous epoch.
            self.model.train()
            mean_loss = 0
            for idx, batch_data_dict in enumerate(tqdm(train_loader)):
                batch_data_dict['target'] = self.forget_label_set(batch_data_dict['target'], ul_class_label)
                batch_data_dict['target'] = batch_data_dict['target'].to(self.device)

                # Forward
                batch_output_dict = self.model(batch_data_dict['waveform'])
                # {'clipwise_output': (batch_size, classes_num), ...}
                batch_target_dict = {'target': batch_data_dict['target']}
                # {'target': (batch_size, classes_num)}

                # Loss
                loss = self.criterion(batch_output_dict, batch_target_dict)
                self.logger.info(f'Batch Training Initial Loss: {loss}')
                if idx % 10 == 0:
                    print(f'Epoch:{epoch},Batch {idx} Loss: {loss}')

                # Backward
                loss.backward()
                self.optimizer.step()
                self.optimizer.zero_grad()

                mean_loss += loss.item()
            # Guard against an empty loader instead of dividing by zero.
            epoch_loss = mean_loss / max(len(train_loader), 1)

            print(f'Epoch {epoch} | Training Loss: {epoch_loss}')
            # Evaluate
            y_true, y_pred = self.evaluator.evaluate(test_loader)
            statistics = self.calculate_metrics(y_true, y_pred, cl_class_label, ul_class_label)

            print(f"Task {cur_iter} |  Epoch {epoch} | statistics {statistics}")
            if statistics['cl_weighted_accuracy'] > best['cl_weighted_accuracy']:
                # Record every metric of the best epoch (ul_* were previously left at 0).
                best.update(statistics)
                best['epoch'] = epoch
                selected_state_dict = {
                    name: param
                    for name, param in self.model.named_parameters()
                    # Precedence made explicit (unchanged selection): projector and
                    # classifier params anywhere, 'fc' params only outside the encoder.
                    if 'projector' in name or 'classifier' in name or ('fc' in name and 'encoder' not in name)
                }
                torch.save(selected_state_dict, SAVEDIR + '{}/task{}best_epoch{}.pt'.format(self.mode, cur_iter, epoch))
                self.counter = 0
            else:
                self.counter += 1
                self.logger.info(f'EarlyStopping counter: {self.counter} out of {self.patience}.')
                if self.counter >= self.patience:
                    break
        # Previously read best['weighted_accuracy'], a key `best` never has -> KeyError.
        print(f"Task {cur_iter} | Best Epoch {best['epoch']} | Best Accuracy: {best['cl_accuracy']}|Best weighted_accuracy: {best['cl_weighted_accuracy']}")
        return train_list, test_list, cl_class_label, ul_class_label
        
    def calculate_metrics(self,y_true,y_pred,cl_class_label,ul_class_label):
        """Score predictions separately for the retained (cl) and forgotten (ul) groups.

        A sample belongs to the cl group when its true label is in
        ``cl_class_label`` and not in ``ul_class_label``; every other sample
        (including labels in neither list) belongs to the ul group.

        Returns:
            dict with keys 'cl_weighted_accuracy', 'ul_weighted_accuracy',
            'cl_accuracy', 'ul_accuracy' holding sklearn's balanced accuracy and
            plain accuracy of each group. An empty group yields whatever sklearn
            returns for empty input (NaN, with a RuntimeWarning).
        """
        cl_true, cl_pred = [], []
        ul_true, ul_pred = [], []
        for truth, pred in zip(y_true, y_pred):
            is_retained = truth in cl_class_label and truth not in ul_class_label
            truths, preds = (cl_true, cl_pred) if is_retained else (ul_true, ul_pred)
            truths.append(truth)
            preds.append(pred)

        # Dict values are evaluated top to bottom, matching the original call order.
        return {
            'cl_weighted_accuracy': balanced_accuracy_score(cl_true, cl_pred),
            'ul_weighted_accuracy': balanced_accuracy_score(ul_true, ul_pred),
            'cl_accuracy': accuracy_score(cl_true, cl_pred),
            'ul_accuracy': accuracy_score(ul_true, ul_pred),
        }

    def get_cl_ul_class_label(self,cur_iter):
        """Accumulate the learn (cl) and unlearn (ul) class labels of tasks 0..cur_iter.

        Args:
            cur_iter: index of the last task to include (inclusive).

        Returns:
            (cl_class_label, ul_class_label): the concatenation, in task order, of
            ``self.cltask['task<i>']`` and ``self.ultask['ul_task<i>']``.
        """
        learned, unlearned = [], []
        for task_idx in range(cur_iter + 1):
            learned.extend(self.cltask[f'task{task_idx}'])
            unlearned.extend(self.ultask[f'ul_task{task_idx}'])
        return learned, unlearned

In [4]:
# train_list,test_list = get_datalist(0)
# train_loader ,test_loader = get_train_test_dataloader(16,0,train_list,test_list)

# for train in train_loader:
#     print(train)
clul = CLUL()

def train_total(num_tasks=5):
    """Run the learn/unlearn schedule over tasks 0..num_tasks-1 with the shared ``clul``.

    Each task is trained with ``clul.train_with_forget_without_forget_bank``. Afterwards
    the returned training records are passed to a sampling routine: class-balanced
    ``equal_class_sampling`` after task 0, and ``single_mutual_info_sampling`` (which
    also receives the accumulated cl/ul labels) for later tasks. These presumably
    refresh ``clul.memory_list`` for the next task — confirm in the CLUL class.

    Args:
        num_tasks: number of tasks to run, starting from task 0. Defaults to 5, the
            previously hard-coded value.
    """
    for task_id in range(num_tasks):
        train_list, _, cl_class_label, ul_class_label = clul.train_with_forget_without_forget_bank(task_id)
        if task_id == 0:
            clul.equal_class_sampling(train_list)
        else:
            clul.single_mutual_info_sampling(train_list, cl_class_label, ul_class_label)

        



  target = torch.FloatTensor(target)
Evaluation starting ...: 100%|██████████| 87/87 [03:04<00:00,  2.13s/it]

Returned target_acc and clipwise_output_acc





In [9]:
# Combine the records of tasks 0 and 1 and inspect them with their category ids.
train_list1, test_list1 = clul.get_train_test_datalist(0)
train_list2, test_list2 = clul.get_train_test_datalist(1)
train_list = train_list1 + train_list2

# get_cl_ul_class_label(i) already accumulates tasks 0..i, so the task-1 result
# covers both tasks; concatenating it with the task-0 result duplicated labels.
cl_class_label, ul_class_label = clul.get_cl_ul_class_label(1)

train_df = pd.DataFrame(train_list)
# Map each record's category name (e.g. 'duck') to its id. This assumes
# ytvos_category_dict maps name -> id (confirm in categories.py); unknown names
# become NaN. The previous `ytvos_category_dict[]` was a SyntaxError.
train_df['category_id'] = train_df['category'].map(ytvos_category_dict)
train_list, cl_class_label, ul_class_label

([{'video': 'e6261d2348',
   'exp': 'a brown duck swimming behind a white one leftwards',
   'audio': '2',
   'category': 'duck'},
  {'video': 'ed9ff4f649',
   'exp': 'a gray with white duck, laying in the water, on the right side of the screen',
   'audio': '2',
   'category': 'duck'},
  {'video': 'b220939e93',
   'exp': 'a brown duck swimming right by a white one by the shore',
   'audio': '1',
   'category': 'duck'},
  {'video': '082900c5d4',
   'exp': 'a goose swimming to the left and putting it s head in the water',
   'audio': '0',
   'category': 'duck'},
  {'video': 'f22d21f1f1',
   'exp': 'a white headed and white chested duck in the center of three other ducks',
   'audio': '4',
   'category': 'duck'},
  {'video': 'ca91e99105',
   'exp': 'a white swan in the water',
   'audio': '1',
   'category': 'duck'},
  {'video': 'efe5ac6901',
   'exp': 'the third flamingo walking towards exit out of the water',
   'audio': '5',
   'category': 'duck'},
  {'video': '8f1600f7f6',
   'exp': 

In [8]:
# Peek at the raw task-0 split (rebinds the train_list / test_list globals).
(train_list, test_list) = get_datalist(0)
train_list

[{'video': 'e6261d2348',
  'exp': 'a brown duck swimming behind a white one leftwards',
  'audio': '2',
  'category': 'duck'},
 {'video': 'ed9ff4f649',
  'exp': 'a gray with white duck, laying in the water, on the right side of the screen',
  'audio': '2',
  'category': 'duck'},
 {'video': 'b220939e93',
  'exp': 'a brown duck swimming right by a white one by the shore',
  'audio': '1',
  'category': 'duck'},
 {'video': '082900c5d4',
  'exp': 'a goose swimming to the left and putting it s head in the water',
  'audio': '0',
  'category': 'duck'},
 {'video': 'f22d21f1f1',
  'exp': 'a white headed and white chested duck in the center of three other ducks',
  'audio': '4',
  'category': 'duck'},
 {'video': 'ca91e99105',
  'exp': 'a white swan in the water',
  'audio': '1',
  'category': 'duck'},
 {'video': 'efe5ac6901',
  'exp': 'the third flamingo walking towards exit out of the water',
  'audio': '5',
  'category': 'duck'},
 {'video': '8f1600f7f6',
  'exp': 'a duck on the corner just lik

In [37]:
from enum import Enum
from sklearn.metrics import f1_score,precision_recall_curve,precision_score,recall_score,accuracy_score,balanced_accuracy_score
# Toy labels: classes 0 and 1 are predicted perfectly, while the single class-2
# sample is predicted as class 4, a label that never occurs in y_true.
y_true = [0] * 3 + [1] * 3 + [2]
y_pred = [0] * 3 + [1] * 3 + [4]
class AverageMethod(str, Enum):
    """Averaging modes passed as ``average=`` to sklearn's f1/precision/recall scorers.

    Mixing in ``str`` makes each member compare equal to its plain string value
    (e.g. ``AverageMethod.MACRO == 'macro'``).
    """
    MICRO = 'micro'
    WEIGHTED = 'weighted'
    MACRO = 'macro'
def evaluate(y_true, y_pred, average: AverageMethod, zero_division="warn"):
    """Compute F1, precision and recall with the given averaging, plus plain accuracy.

    Args:
        y_true: ground-truth labels.
        y_pred: predicted labels, same length as ``y_true``.
        average: an ``AverageMethod`` member or its string value
            ('micro', 'weighted', 'macro').
        zero_division: forwarded to sklearn's f1/precision/recall. The default
            "warn" is sklearn's own default: an ill-defined score (a class with no
            predicted or no true samples) counts as 0 and emits an
            UndefinedMetricWarning.

    Returns:
        dict with keys 'f1_score', 'precision', 'recall', 'accuracy'.
    """
    # Enum lookup accepts both a member and its raw string value; calling
    # `.value` directly crashed for plain strings such as 'macro'.
    average_value = AverageMethod(average).value
    statistics = {}
    statistics['f1_score'] = f1_score(y_true, y_pred, average=average_value, zero_division=zero_division)
    statistics['precision'] = precision_score(y_true, y_pred, average=average_value, zero_division=zero_division)
    statistics['recall'] = recall_score(y_true, y_pred, average=average_value, zero_division=zero_division)
    statistics['accuracy'] = accuracy_score(y_true, y_pred)
    return statistics
# Macro-averaged scores shown next to sklearn's balanced accuracy (mean per-class recall).
(
    evaluate(y_true, y_pred, average=AverageMethod.MACRO),
    balanced_accuracy_score(y_true, y_pred),
)

  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


({'f1_score': 0.5,
  'precision': 0.5,
  'recall': 0.5,
  'accuracy': 0.8571428571428571},
 0.6666666666666666)

In [52]:
# NOTE(review): this cell does not define `t`, and it calls `forget_label_set` as a
# free function (the training method calls it as `self.forget_label_set`). It therefore
# depends on kernel state from another cell and will raise NameError on a fresh kernel.
# `t` is presumably one batch dict from a DataLoader — TODO define both names here.
# Only a cell's last expression is displayed, so this first line shows nothing.
t['target']



# Presumably relabels t['target'] in place for the classes to forget (17, 4, 5),
# since the result is not assigned and t['target'] is re-displayed below — confirm.
forget_label_set(t['target'],[17,4,5])
t['target']

tensor([[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]])