# Cycle Self-Training (CST) Implementation on the DCASE TAU 2020 Dataset

In [1]:
import random
import time
import warnings
import sys
import argparse
import shutil
import os.path as osp
import os
import torch
import torch.nn as nn
import torch.backends.cudnn as cudnn
from torch.optim import SGD
from torch.optim.lr_scheduler import LambdaLR
from torch.utils.data import DataLoader
import torchvision.transforms as T
import torch.nn.functional as F

import librosa
from hear21passt.base import get_basic_model
import datetime
import matplotlib.pyplot as plt
from collections import defaultdict
from typing import Tuple, Optional, List, Dict
from types import SimpleNamespace

In [2]:
sys.path.append(os.path.dirname(os.getcwd()))
import common.audio.datasets as datasets
from common.audio.transforms import TransformFixMatch
from common.audio.transforms import WeakAugment
from common.audio.transforms import StrongAugment
from common.utils.data import ForeverDataIterator
from common.utils.metric import accuracy, ConfusionMatrix
from common.utils.meter import AverageMeter, ProgressMeter
from common.utils.logger import CompleteLogger
from common.utils.analysis import collect_feature, tsne, a_distance
from common.tools.randaugment import rand_augment_transform, GaussianBlur
from common.tools.sam import SAM

In [3]:
# Run on the GPU when PyTorch can see one; otherwise fall back to the CPU.
device = torch.device("cpu")
if torch.cuda.is_available():
    device = torch.device("cuda")
print("Using Device:", device)

Using Device: cuda


In [4]:
args = SimpleNamespace(
    root = './data/dcase',
    data = 'DCASE',          
    source = 'source',      
    target = 'target',
    temperature = 2.0,
    alpha = 1.9,
    trade_off = 0.08,
    trade_off1 = 0.5,
    trade_off3 = 0.5,
    threshold = 0.97,
    rho = 0.5,
    batch_size = 3,
    lr = 0.0005,
    lr_gamma = 0.001,
    lr_decay = 0.75,
    momentum = 0.9,
    weight_decay = 1e-3,
    workers = 2,
    epochs = 50,
    early = 45,
    iters_per_epoch = 1000,
    val_interval = 10,
    print_freq = 100,        
    seed = None,             
    per_class_eval = False,  
    log = 'logs/dcase',             
    phase = 'train', 
    sample_rate = 32000,
    clip_length = 10,
    num_cls = 10,
    device = device,
    
)
print("Arguments set for CST experiment (DCASE TAU 2020):")
for k, v in args.__dict__.items():
    print(f"{k} = {v}")


Arguments set for CST experiment (DCASE TAU 2020):
root = ./data/dcase
data = DCASE
source = source
target = target
temperature = 2.0
alpha = 1.9
trade_off = 0.08
trade_off1 = 0.5
trade_off3 = 0.5
threshold = 0.97
rho = 0.5
batch_size = 3
lr = 0.0005
lr_gamma = 0.001
lr_decay = 0.75
momentum = 0.9
weight_decay = 0.001
workers = 2
epochs = 50
early = 45
iters_per_epoch = 1000
val_interval = 10
print_freq = 100
seed = None
per_class_eval = False
log = logs/dcase
phase = train
sample_rate = 32000
clip_length = 10
num_cls = 10
device = cuda


In [5]:
def entropy(predictions: torch.Tensor, reduction='none') -> torch.Tensor:
    """Shannon entropy of each row of a probability matrix.

    Args:
        predictions: probabilities; entropy is summed over dim 1.
        reduction: 'mean' averages the per-row entropies; any other value
            returns them unreduced.

    Returns:
        A 1-D tensor of per-row entropies, or its mean when reduction == 'mean'.
    """
    eps = 1e-5  # keeps log() finite when a probability is exactly 0
    per_row = (-predictions * torch.log(predictions + eps)).sum(dim=1)
    return per_row.mean() if reduction == 'mean' else per_row

class TsallisEntropy(nn.Module):
    """Entropy-weighted Tsallis-entropy loss on a batch of logits.

    Args:
        temperature: logits are divided by this before the softmax.
        alpha: Tsallis exponent; the result is scaled by 1 / (alpha - 1).
    """

    def __init__(self, temperature: float, alpha: float):
        super().__init__()
        self.temperature = temperature
        self.alpha = alpha

    def forward(self, logits: torch.Tensor) -> torch.Tensor:
        """Return the scalar loss for ``logits`` of shape (N, C)."""
        batch, _ = logits.shape
        probs = F.softmax(logits / self.temperature, dim=1)

        # Per-sample weights 1 + exp(-H(p)), rescaled to sum to N; computed on
        # detached entropies so they act as constants in the gradient.
        row_entropy = (-probs * torch.log(probs + 1e-5)).sum(dim=1).detach()
        weights = 1 + torch.exp(-row_entropy)
        weights = (batch * weights / weights.sum()).unsqueeze(dim=1)

        # Weighted probability mass per class, shape (1, C).
        class_mass = torch.sum(probs * weights, dim=0).unsqueeze(dim=0)

        per_sample = 1 / torch.mean(class_mass) - torch.sum(
            probs ** self.alpha / class_mass * weights, dim=-1)
        return 1 / (self.alpha - 1) * torch.sum(per_sample)


In [6]:
class AudioClassifier(nn.Module):
    """Backbone -> bottleneck (Linear + BatchNorm1d + ReLU) -> linear head.

    Args:
        backbone: module mapping an input batch to (batch, feature_dim) features.
        num_classes: number of output logits.
        bottleneck_dim: width of the bottleneck embedding.
        finetune: when False, every backbone parameter gets requires_grad=False.
        feature_dim: width of the backbone output. The default of 768 is the
            value this class previously hard-coded.
    """

    def __init__(self, backbone, num_classes, bottleneck_dim=256, finetune=True, feature_dim=768):
        super().__init__()
        self.backbone = backbone
        self.bottleneck = nn.Sequential(
            nn.Linear(feature_dim, bottleneck_dim),
            nn.BatchNorm1d(bottleneck_dim),
            nn.ReLU(),
        )
        self.head = nn.Linear(bottleneck_dim, num_classes)
        self.finetune = finetune

        if not finetune:
            for param in self.backbone.parameters():
                param.requires_grad = False

    def forward(self, x):
        """Return ``(logits, embeddings)``, where embeddings is the bottleneck output."""
        features = self.backbone(x)
        embeddings = self.bottleneck(features)
        outputs = self.head(embeddings)
        return outputs, embeddings

    def get_parameters(self, base_lr=1.0) -> List[Dict]:
        """Parameter groups with per-module learning rates.

        A fine-tuned backbone gets 0.1 * base_lr; the bottleneck and head get
        base_lr. A frozen backbone is listed with base_lr, but its parameters
        have requires_grad=False, so they receive no gradients.
        """
        backbone_lr = 0.1 * base_lr if self.finetune else 1.0 * base_lr
        return [
            {"params": self.backbone.parameters(), "lr": backbone_lr},
            {"params": self.bottleneck.parameters(), "lr": 1.0 * base_lr},
            {"params": self.head.parameters(), "lr": 1.0 * base_lr},
        ]

In [7]:
# Data loading code
train_transform = WeakAugment()
unlabeled_transform = TransformFixMatch()
val_transform = WeakAugment()

dataset = datasets.__dict__[args.data]


def _train_val_split(train_view, eval_view, train_fraction=0.9):
    """Randomly split one 'train' split into disjoint train / validation subsets.

    ``train_view`` and ``eval_view`` are two instances of the same split that
    differ only in their transform. One random permutation indexes both, so
    training samples come from ``train_view`` and the held-out samples come
    from ``eval_view``, with no overlap. Assumes both instances enumerate their
    samples in the same order (TODO: confirm for the DCASE dataset class).
    """
    n = len(train_view)
    n_train = int(train_fraction * n)
    order = torch.randperm(n).tolist()
    train_subset = torch.utils.data.Subset(train_view, order[:n_train])
    val_subset = torch.utils.data.Subset(eval_view, order[n_train:])
    # validate() reads `.dataset.classes` when per_class_eval is on, and Subset
    # does not forward attributes, so copy it across when available.
    if hasattr(eval_view, 'classes'):
        val_subset.classes = eval_view.classes
    return train_subset, val_subset


# Split BEFORE building loaders, so training only sees the 90% part and
# validation only sees held-out samples (never the data being trained on).
train_source_dataset, val_source_dataset = _train_val_split(
    dataset(root=args.root, task=args.source, split='train', transform=train_transform),
    dataset(root=args.root, task=args.source, split='train', transform=val_transform),
)
train_target_dataset, val_target_dataset = _train_val_split(
    dataset(root=args.root, task=args.target, split='train', transform=unlabeled_transform),
    dataset(root=args.root, task=args.target, split='train', transform=val_transform),
)

train_source_loader = DataLoader(train_source_dataset, batch_size=args.batch_size, shuffle=True,
                                 num_workers=args.workers, drop_last=True)
train_target_loader = DataLoader(train_target_dataset, batch_size=args.batch_size, shuffle=True,
                                 num_workers=args.workers, drop_last=True)

val_source_loader = DataLoader(val_source_dataset, batch_size=args.batch_size, shuffle=False,
                               num_workers=args.workers)
val_target_loader = DataLoader(val_target_dataset, batch_size=args.batch_size, shuffle=False,
                               num_workers=args.workers)

test_source_dataset = dataset(root=args.root, task=args.source, split='test', transform=val_transform)
test_source_loader = DataLoader(test_source_dataset, batch_size=args.batch_size, shuffle=False,
                                num_workers=args.workers)

test_target_dataset = dataset(root=args.root, task=args.target, split='test', transform=val_transform)
test_target_loader = DataLoader(test_target_dataset, batch_size=args.batch_size, shuffle=False,
                                num_workers=args.workers)

train_source_iter = ForeverDataIterator(train_source_loader)
train_target_iter = ForeverDataIterator(train_target_loader)

In [8]:
# create model
print("=> using pre-trained model PaSST")
backbone = get_basic_model(mode="embed_only")
# Note: classifier.train()/eval() later switch the backbone's mode as well,
# because nn.Module.train/eval recurse into submodules.
backbone.eval()
# Size the head from args.num_cls, the same value train() uses for its one-hot
# targets, so the two cannot drift apart.
classifier = AudioClassifier(backbone, num_classes=args.num_cls).to(device)

=> using pre-trained model PaSST


 Loading PASST TRAINED ON AUDISET 


PaSST(
  (patch_embed): PatchEmbed(
    (proj): Conv2d(1, 768, kernel_size=(16, 16), stride=(10, 10))
    (norm): Identity()
  )
  (pos_drop): Dropout(p=0.0, inplace=False)
  (blocks): Sequential(
    (0): Block(
      (norm1): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
      (attn): Attention(
        (qkv): Linear(in_features=768, out_features=2304, bias=True)
        (attn_drop): Dropout(p=0.0, inplace=False)
        (proj): Linear(in_features=768, out_features=768, bias=True)
        (proj_drop): Dropout(p=0.0, inplace=False)
      )
      (drop_path): Identity()
      (norm2): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
      (mlp): Mlp(
        (fc1): Linear(in_features=768, out_features=3072, bias=True)
        (act): GELU(approximate='none')
        (fc2): Linear(in_features=3072, out_features=768, bias=True)
        (drop): Dropout(p=0.0, inplace=False)
      )
    )
    (1): Block(
  

In [9]:
def main(args: argparse.Namespace):
    """Run the CST experiment phase selected by ``args.phase``.

    * 'analysis': load the 'best' checkpoint, extract bottleneck features from
      source and target, save a t-SNE plot and print the A-distance.
    * 'test': load the 'best' checkpoint and print source/target test accuracy.
    * 'train' (default): train for min(args.epochs, args.early) epochs.
      Validate every ``args.val_interval`` epochs and after the last epoch.
      Keep a 'latest' checkpoint and a 'best' checkpoint, selected on target
      validation accuracy. Finally report test accuracy of 'best'.

    Uses module-level objects from earlier cells: ``classifier``, the data
    loaders/iterators, ``train`` and ``validate``.
    """
    logger = CompleteLogger(args.log, args.phase)
    print(args)

    if args.seed is not None:
        random.seed(args.seed)
        torch.manual_seed(args.seed)
        cudnn.deterministic = True
        warnings.warn('You have chosen to seed training. '
                      'This will turn on the CUDNN deterministic setting, '
                      'which can slow down your training considerably! '
                      'You may see unexpected behavior when restarting '
                      'from checkpoints.')

    # cuDNN benchmark mode auto-tunes kernels non-deterministically, which would
    # undo the deterministic setting requested by a seed.
    cudnn.benchmark = args.seed is None

    # Define optimizer and lr scheduler. get_parameters() is called with its
    # default base_lr=1.0, so each group's "lr" is a relative multiplier
    # (0.1 for a fine-tuned backbone, 1.0 for bottleneck/head). LambdaLR then
    # multiplies it by args.lr * (1 + lr_gamma * step) ** (-lr_decay).
    optimizer = SAM(classifier.get_parameters(), SGD, lr=args.lr,
                    momentum=args.momentum, weight_decay=args.weight_decay,
                    adaptive=True, rho=args.rho)
    lr_scheduler = LambdaLR(optimizer, lambda x: args.lr * (1. + args.lr_gamma * float(x)) ** (-args.lr_decay))

    # define loss function
    ts_loss = TsallisEntropy(temperature=args.temperature, alpha=args.alpha)

    # resume from the best checkpoint
    if args.phase != 'train':
        checkpoint = torch.load(logger.get_checkpoint_path('best'), map_location='cpu')
        classifier.load_state_dict(checkpoint)

    if args.phase == 'analysis':
        # Target *training* batches carry a pair of views (see train()), which a
        # single-input feature extractor cannot consume. The validation loaders
        # use val_transform and yield one tensor per sample in both domains.
        # NOTE(review): assumes collect_feature feeds batch[0] straight to the
        # model; confirm.
        feature_extractor = nn.Sequential(classifier.backbone, classifier.bottleneck).to(args.device)
        source_feature = collect_feature(val_source_loader, feature_extractor, args.device)
        target_feature = collect_feature(val_target_loader, feature_extractor, args.device)
        # plot t-SNE
        tSNE_filename = osp.join(logger.visualize_directory, 'TSNE.png')
        tsne.visualize(source_feature, target_feature, tSNE_filename)
        print("Saving t-SNE to", tSNE_filename)
        # calculate A-distance, which is a measure for distribution discrepancy
        A_distance = a_distance.calculate(source_feature, target_feature, args.device)
        print("A-distance =", A_distance)
        return

    if args.phase == 'test':
        acc_s, _ = validate(test_source_loader, classifier, args)
        print(acc_s)

        acc_t, _ = validate(test_target_loader, classifier, args)
        print(acc_t)
        return

    # start training
    num_epochs = min(args.epochs, args.early)
    latest_path = logger.get_checkpoint_path('latest')
    best_path = logger.get_checkpoint_path('best')
    # Start below any reachable accuracy, so the first validation always writes
    # 'best' (otherwise an all-zero run would leave nothing to load at the end).
    best_acc_s = -1.
    best_acc_t = -1.
    for epoch in range(num_epochs):
        print("lr:", lr_scheduler.get_last_lr()[0])
        # train for one epoch
        train(train_source_iter, train_target_iter, classifier, ts_loss, optimizer,
              lr_scheduler, epoch, args)
        torch.save(classifier.state_dict(), latest_path)

        # Validate every val_interval epochs, and always after the last epoch
        # that actually runs (the loop stops at min(epochs, early)).
        if (epoch + 1) % args.val_interval == 0 or epoch == num_epochs - 1:
            acc_s, _ = validate(val_source_loader, classifier, args)
            acc_t, _ = validate(val_target_loader, classifier, args)

            # Model selection uses one criterion, target validation accuracy.
            # Source accuracy is tracked for reporting only.
            if acc_t > best_acc_t:
                shutil.copy(latest_path, best_path)
            best_acc_t = max(acc_t, best_acc_t)
            best_acc_s = max(acc_s, best_acc_s)

            print("Best Source Accuracy = {:3.1f}".format(best_acc_s))
            print("Best Target Accuracy = {:3.1f}".format(best_acc_t))

    # evaluate on test set
    classifier.load_state_dict(torch.load(best_path, map_location=args.device))
    acc_s, _ = validate(test_source_loader, classifier, args)
    print("Test Source Accuracy = {:3.1f}".format(acc_s))

    acc_t, _ = validate(test_target_loader, classifier, args)
    print("Test Target Accuracy = {:3.1f}".format(acc_t))

    logger.close()


In [10]:
def _cst_objective(model, x, x_t_u, labels_s, ts, args):
    """Evaluate the full CST training objective once at the current weights.

    ``x`` is the source batch concatenated with the first target view along
    dim 0; ``x_t_u`` is the second target view.

    Returns:
        ``(loss, y_s, cls_loss, transfer_loss, reverse_loss, Lu)``, where
        ``y_s`` are the source logits.
    """
    y, f = model(x)
    y_t_u, _ = model(x_t_u)

    f_s, f_t = f.chunk(2, dim=0)
    y_s, y_t = y.chunk(2, dim=0)

    # Pseudo-labels come from the first target view. Only predictions with
    # max probability >= args.threshold contribute to the loss Lu, which is
    # applied to the second view.
    max_prob, pred_u = torch.max(F.softmax(y_t, dim=1), dim=-1)
    confident = max_prob.ge(args.threshold).float().detach()
    Lu = (F.cross_entropy(y_t_u, pred_u, reduction='none') * confident).mean()

    # Reverse (cycle) step: fit a kernel ridge regressor on L2-normalised target
    # features with centred one-hot pseudo-labels, predict the source labels,
    # and penalise the squared error. F.normalize guards against an all-zero
    # (post-ReLU) embedding, where dividing by the norm would produce NaN.
    f_t_n = F.normalize(f_t, dim=-1)
    f_s_n = F.normalize(f_s, dim=-1)
    bound = 0.99999999
    k_tt = torch.clamp(f_t_n.mm(f_t_n.t()), -bound, bound)
    k_st = torch.clamp(f_s_n.mm(f_t_n.t()), -bound, bound)
    pseudo_targets = F.one_hot(pred_u, args.num_cls) - 1 / float(args.num_cls)
    ridge = k_tt + 0.001 * torch.eye(k_tt.size(0), device=k_tt.device, dtype=k_tt.dtype)
    # k_st @ inv(ridge) @ Y, computed with a linear solve instead of an explicit inverse.
    source_pred = k_st.mm(torch.linalg.solve(ridge, pseudo_targets))
    source_targets = F.one_hot(labels_s, args.num_cls) - 1 / float(args.num_cls)
    reverse_loss = F.mse_loss(source_pred, source_targets)

    cls_loss = F.cross_entropy(y_s, labels_s)
    transfer_loss = ts(y_t)

    # When no target sample passes the threshold, Lu is exactly 0. Adding its
    # term unconditionally is then identical to leaving it out.
    loss = (cls_loss + transfer_loss * args.trade_off
            + reverse_loss * args.trade_off1 + Lu * args.trade_off3)
    return loss, y_s, cls_loss, transfer_loss, reverse_loss, Lu


def train(train_source_iter: ForeverDataIterator, train_target_iter: ForeverDataIterator,
          model: AudioClassifier, ts: TsallisEntropy, optimizer: SGD,
          lr_scheduler: LambdaLR, epoch: int, args: argparse.Namespace):
    """Train ``model`` for one epoch of ``args.iters_per_epoch`` SAM updates.

    Each iteration draws a labelled source batch and an unlabelled target batch
    with two views ``(x_t, x_t_u)``, then performs a two-pass SAM update:
    ``optimizer.first_step`` uses the gradient at the current weights, and
    ``optimizer.second_step`` uses the gradient of the objective re-evaluated
    after the first step. ``lr_scheduler`` steps once per iteration, between
    the two passes.

    Args:
        train_source_iter: yields ``(x_s, labels_s, _, _)`` batches.
        train_target_iter: yields ``((x_t, x_t_u), _, _, _)`` batches.
        model: classifier returning ``(logits, embeddings)``.
        ts: transfer loss applied to the target logits.
        optimizer: SAM-style optimizer exposing ``first_step``/``second_step``.
        lr_scheduler: stepped once per iteration.
        epoch: epoch index, used only in the progress-display prefix.
        args: reads device, iters_per_epoch, threshold, num_cls, trade_off,
            trade_off1, trade_off3 and print_freq.
    """
    batch_time = AverageMeter('Time', ':3.1f')
    data_time = AverageMeter('Data', ':3.1f')
    losses = AverageMeter('Loss', ':3.2f')
    trans_losses = AverageMeter('Trans Loss', ':3.2f')
    rev_losses = AverageMeter('CST Loss', ':3.2f')
    fix_losses = AverageMeter('Fix Loss', ':3.2f')
    cls_accs = AverageMeter('Cls Acc', ':3.1f')

    progress = ProgressMeter(
        args.iters_per_epoch,
        [batch_time, data_time, losses, trans_losses, rev_losses, fix_losses, cls_accs],
        prefix="Epoch: [{}]".format(epoch))

    # switch to train mode
    model.train()

    end = time.time()
    for i in range(args.iters_per_epoch):
        x_s, labels_s, _, _ = next(train_source_iter)
        (x_t, x_t_u), _, _, _ = next(train_target_iter)

        x_s = x_s.to(args.device)
        x_t = x_t.to(args.device)
        x_t_u = x_t_u.to(args.device)
        labels_s = labels_s.to(args.device)

        # measure data loading time
        data_time.update(time.time() - end)

        x = torch.cat((x_s, x_t), dim=0)

        # First SAM pass: objective and gradient at the current weights.
        loss, y_s, cls_loss, transfer_loss, reverse_loss, Lu = _cst_objective(
            model, x, x_t_u, labels_s, ts, args)

        cls_acc = accuracy(y_s, labels_s)[0]
        batch = x_s.size(0)
        losses.update(loss.item(), batch)
        cls_accs.update(cls_acc.item(), batch)
        trans_losses.update(transfer_loss.item(), batch)
        rev_losses.update(reverse_loss.item(), batch)
        fix_losses.update(Lu.item(), batch)

        loss.backward()
        optimizer.first_step(zero_grad=True)
        lr_scheduler.step()

        # Second SAM pass: re-evaluate the same objective after first_step.
        perturbed_loss = _cst_objective(model, x, x_t_u, labels_s, ts, args)[0]
        perturbed_loss.backward()
        optimizer.second_step(zero_grad=True)

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        if i % args.print_freq == 0:
            progress.display(i)


In [11]:
def evaluate_by_device(model, test_target_loader, device):
    """Compute and print top-1 accuracy separately for each recording device.

    Args:
        model: classifier returning either logits or a ``(logits, ...)`` tuple.
            ``AudioClassifier`` returns ``(outputs, embeddings)``.
        test_target_loader: iterable of ``(audio, labels, domains, devices)``
            batches, where ``devices`` holds one device id per sample.
        device: torch device the audio and labels are moved to.

    Returns:
        dict mapping device id -> ``{'accuracy': percent, 'total': count}``.
    """
    model.eval()
    counts = defaultdict(lambda: {'correct': 0, 'total': 0})

    with torch.no_grad():
        for audio, labels, _domains, devices in test_target_loader:
            audio = audio.to(device)
            labels = labels.to(device)

            outputs = model(audio)
            if isinstance(outputs, (tuple, list)):
                outputs = outputs[0]
            # One device->host transfer per batch instead of one per sample.
            hits = outputs.argmax(dim=1).eq(labels).tolist()

            for dev, hit in zip(devices, hits):
                counts[dev]['total'] += 1
                counts[dev]['correct'] += int(hit)

    results = {}
    for dev, tally in counts.items():
        correct = tally['correct']
        total = tally['total']
        dev_acc = 100.0 * correct / total
        results[dev] = {'accuracy': dev_acc, 'total': total}
        print(f'Device {dev}: Accuracy = {dev_acc:.2f}% ({correct}/{total})')

    return results

In [12]:
def validate(val_loader: DataLoader, model: AudioClassifier, args: argparse.Namespace) -> Tuple[float, Dict]:
    """Evaluate ``model`` on ``val_loader`` and print overall and per-device accuracy.

    Args:
        val_loader: yields ``(audio, target, _, devices)`` batches, where
            ``devices`` holds one device id per sample.
        model: classifier returning ``(logits, embeddings)``.
        args: reads device, print_freq and per_class_eval.

    Returns:
        ``(top1_avg, device_results)``: the ``AverageMeter`` average of top-1
        accuracy (weighted by batch size) and a dict mapping device id ->
        ``{'correct': int, 'total': int}``.
    """
    batch_time = AverageMeter('Time', ':6.3f')
    losses = AverageMeter('Loss', ':.4e')
    top1 = AverageMeter('Acc@1', ':6.2f')
    top5 = AverageMeter('Acc@5', ':6.2f')
    progress = ProgressMeter(
        len(val_loader),
        [batch_time, losses, top1, top5],
        prefix='Test: ')

    # switch to evaluate mode
    model.eval()
    device_results = defaultdict(lambda: {'correct': 0, 'total': 0})
    classes = None
    confmat = None
    if args.per_class_eval:
        # torch Subset wrappers do not expose `.classes`; unwrap until found.
        ds = val_loader.dataset
        while not hasattr(ds, 'classes') and hasattr(ds, 'dataset'):
            ds = ds.dataset
        classes = ds.classes
        confmat = ConfusionMatrix(len(classes))

    with torch.no_grad():
        end = time.time()
        for i, (audio, target, _, devices) in enumerate(val_loader):
            audio = audio.to(args.device)
            target = target.to(args.device)

            # compute output
            output, _ = model(audio)
            loss = F.cross_entropy(output, target)
            pred = output.argmax(dim=1)

            # measure accuracy and record loss
            acc1, acc5 = accuracy(output, target, topk=(1, 5))
            if confmat is not None:
                confmat.update(target, pred)
            losses.update(loss.item(), audio.size(0))
            top1.update(acc1.item(), audio.size(0))
            top5.update(acc5.item(), audio.size(0))

            # Per-device tallies; one device->host transfer per batch.
            for dev, hit in zip(devices, pred.eq(target).tolist()):
                device_results[dev]['total'] += 1
                device_results[dev]['correct'] += int(hit)

            # measure elapsed time
            batch_time.update(time.time() - end)
            end = time.time()

            if i % args.print_freq == 0:
                progress.display(i)

        print(' * Acc@1 {top1.avg:.3f} Acc@5 {top5.avg:.3f}'
              .format(top1=top1, top5=top5))
        if confmat is not None:
            print(confmat.format(classes))

        print("\nDevice-specific results:")
        for dev, tally in device_results.items():
            correct = tally['correct']
            total = tally['total']
            acc = 100.0 * correct / total
            print(f'Device {dev}: {acc:.2f}% ({correct}/{total})')

    return top1.avg, device_results


In [None]:
# Run the experiment phase selected by args.phase ('train', 'test' or 'analysis').
main(args)

namespace(root='./data/dcase', data='DCASE', source='source', target='target', temperature=2.0, alpha=1.9, trade_off=0.08, trade_off1=0.5, trade_off3=0.5, threshold=0.97, rho=0.5, batch_size=3, lr=0.0005, lr_gamma=0.001, lr_decay=0.75, momentum=0.9, weight_decay=0.001, workers=2, epochs=50, early=45, iters_per_epoch=1000, val_interval=10, print_freq=100, seed=None, per_class_eval=False, log='logs/dcase', phase='train', sample_rate=32000, clip_length=10, num_cls=10, device=device(type='cuda'))
lr: 5e-05
Note: you can still call torch.view_as_real on the complex output to recover the old return format. (Triggered internally at /pytorch/aten/src/ATen/native/SpectralOps.cpp:875.)
  return _VF.stft(  # type: ignore[attr-defined]
  with torch.cuda.amp.autocast(enabled=False):
x torch.Size([6, 1, 128, 1000])
self.norm(x) torch.Size([6, 768, 12, 99])
 patch_embed :  torch.Size([6, 768, 12, 99])
 self.time_new_pos_embed.shape torch.Size([1, 768, 1, 99])
 self.freq_new_pos_embed.shape torch.Size