In [None]:
# !pip install pip==24.0
# !pip install fairseq
!pip install omegaconf
!pip install conformer
!pip install ml-collections
# !wget https://dl.fbaipublicfiles.com/fairseq/wav2vec/xlsr2_300m.pt

Collecting ml-collections
  Using cached ml_collections-0.1.1-py3-none-any.whl
Installing collected packages: ml-collections
Successfully installed ml-collections-0.1.1


## Temporal Channel Modelling

In [None]:
import torch
import torch.nn as nn
# import fairseq
from conformer import ConformerBlock
from torch.nn.modules.transformer import _get_clones
from torch import Tensor

def sinusoidal_embedding(n_channels, dim):
    pe = torch.FloatTensor([[p / (10000 ** (2 * (i // 2) / dim)) for i in range(dim)]
                            for p in range(n_channels)])
    pe[:, 0::2] = torch.sin(pe[:, 0::2])
    pe[:, 1::2] = torch.cos(pe[:, 1::2])
    return pe.unsqueeze(0)

class MyConformer(nn.Module):
  def __init__(self, emb_size=128, heads=4, ffmult=4, exp_fac=2, kernel_size=16, n_encoders=1):
    super(MyConformer, self).__init__()
    self.dim_head=int(emb_size/heads)
    self.dim=emb_size
    self.heads=heads
    self.kernel_size=kernel_size
    self.n_encoders=n_encoders
    self.positional_emb = nn.Parameter(sinusoidal_embedding(10000, emb_size), requires_grad=False)
    self.encoder_blocks=_get_clones( ConformerBlock( dim = emb_size, dim_head=self.dim_head, heads= heads,
    ff_mult = ffmult, conv_expansion_factor = exp_fac, conv_kernel_size = kernel_size),
    n_encoders)
    self.class_token = nn.Parameter(torch.rand(1, emb_size))
    self.fc5 = nn.Linear(emb_size, 2)

  def forward(self, x, device): # x shape [bs, tiempo, frecuencia]
    x = x + self.positional_emb[:, :x.size(1), :]
    x = torch.stack([torch.vstack((self.class_token, x[i])) for i in range(len(x))])#[bs,1+tiempo,emb_size]
    print(f"Shape after adding class token: {x.shape}")
    list_attn_weight = []
    for layer in self.encoder_blocks:
            x, attn_weight = layer(x) #[bs,1+tiempo,emb_size]
            list_attn_weight.append(attn_weight)
    # embedding=x[:,0,:] #[bs, emb_size]
    # Debug the shape of x
    print(f"Shape of x before extracting embedding: {x.shape}")  # Should be [batch_size, 1+time, emb_size]

    # Ensure that x has the correct shape (3D) before indexing
    if len(x.shape) == 3:
      embedding = x[:, 0, :]  # [batch_size, emb_size]
    else:
      raise ValueError(f"Unexpected tensor shape: {x.shape}")

    out=self.fc5(embedding) #[bs,2]
    return out, list_attn_weight

class SSLModel(nn.Module): #W2V
    def __init__(self,device):
        super(SSLModel, self).__init__()
        cp_path = 'xlsr2_300m.pt'   # Change the pre-trained XLSR model path.
        model, cfg, task = fairseq.checkpoint_utils.load_model_ensemble_and_task([cp_path])
        self.model = model[0]
        self.device=device
        self.out_dim = 1024
        return

    def extract_feat(self, input_data):
        # put the model to GPU if it not there
        if next(self.model.parameters()).device != input_data.device \
           or next(self.model.parameters()).dtype != input_data.dtype:
            self.model.to(input_data.device, dtype=input_data.dtype)
            self.model.train()

        # input should be in shape (batch, length)
        if input_data.ndim == 3:
            input_tmp = input_data[:, :, 0]
        else:
            input_tmp = input_data

        # [batch, length, dim]
        emb = self.model(input_tmp, mask=False, features_only=True)['x']
        return emb

class Model(nn.Module):
    def __init__(self, args, device):
        super().__init__()
        self.device=device
        ####
        # create network wav2vec 2.0
        ####
        self.ssl_model = SSLModel(self.device)
        self.LL = nn.Linear(1024, args.emb_size)
        print('W2V + Conformer')
        self.first_bn = nn.BatchNorm2d(num_features=1)
        self.selu = nn.SELU(inplace=True)
        self.conformer=MyConformer(emb_size=args.emb_size, n_encoders=args.num_encoders,
        heads=args.heads, kernel_size=args.kernel_size)
    def forward(self, x):
        #-------pre-trained Wav2vec model fine tunning ------------------------##
        x_ssl_feat = self.ssl_model.extract_feat(x.squeeze(-1))
        x=self.LL(x_ssl_feat) #(bs,frame_number,feat_out_dim) (bs, 208, 256)
        x = x.unsqueeze(dim=1) # add channel #(bs, 1, frame_number, 256)
        x = self.first_bn(x)
        x = self.selu(x)
        x = x.squeeze(dim=1)
        out, attn_score =self.conformer(x,self.device)
        return out, attn_score


In [None]:
import numpy as np
import soundfile as sf

# Parameters
duration = 10.0  # seconds
sample_rate = 16000
frequency = 440.0  #

# Time axis
t = np.linspace(0.0, duration, int(sample_rate * duration))

audio_data = 0.5 * np.sin(2 * np.pi * frequency * t)

# Save the audio data as a .flac file in Colab
file_path = 'sa.flac'
sf.write(file_path, audio_data, sample_rate)

print(f"Audio file saved as {file_path}")


Audio file saved as sa.flac


In [None]:
import torch
class MyConformer(nn.Module):
    def __init__(self, emb_size=128, heads=4, ffmult=4, exp_fac=2, kernel_size=16, n_encoders=1):
        super(MyConformer, self).__init__()
        self.dim_head = int(emb_size / heads)
        self.dim = emb_size
        self.heads = heads
        self.kernel_size = kernel_size
        self.n_encoders = n_encoders
        self.positional_emb = nn.Parameter(sinusoidal_embedding(10000, emb_size), requires_grad=False)
        self.encoder_blocks = _get_clones(ConformerBlock(dim=emb_size, dim_head=self.dim_head, heads=heads,
                                                         ff_mult=ffmult, conv_expansion_factor=exp_fac,
                                                         conv_kernel_size=kernel_size), n_encoders)
        self.class_token = nn.Parameter(torch.rand(1, emb_size))
        self.fc5 = nn.Linear(emb_size, 2)

    def forward(self, x, device):
        # Add positional embedding
        x = x + self.positional_emb[:, :x.size(1), :]
        x = torch.stack([torch.vstack((self.class_token, x[i])) for i in range(len(x))])  # [bs,1+time,emb_size]
        print(f"Shape after adding class token: {x.shape}")


        list_attn_weight = []

        # Process through ConformerBlocks
        for layer in self.encoder_blocks:

            output = layer(x)
            x = output



        embedding = x[:, 0, :]  # [bs, emb_size]
        out = self.fc5(embedding)  # [bs, 2]
        return out, list_attn_weight

batch_size = 1
time_steps = 100
emb_size = 128


random_input = torch.randn(batch_size, time_steps, emb_size)


conformer_model = MyConformer(emb_size=128, heads=4, ffmult=4, exp_fac=2, kernel_size=16, n_encoders=1)


device = torch.device('cpu')


conformer_model.eval()
with torch.no_grad():
    output, attention_weights = conformer_model(random_input, device)


print("Output shape:", output.shape)
print("Output:", output)
print("Attention Weights:", attention_weights)


Shape after adding class token: torch.Size([1, 101, 128])
Output shape: torch.Size([1, 2])
Output: tensor([[0.9717, 0.1722]])
Attention Weights: []


# Image

In [None]:
# coding=utf-8

import copy
import math
from os.path import join as pjoin

import torch
import torch.nn as nn
from torch.nn import Dropout, Softmax, Linear, Conv2d, LayerNorm
from torch.nn.modules.utils import _pair

ATTENTION_Q = "MultiHeadDotProductAttention_1/query"
ATTENTION_K = "MultiHeadDotProductAttention_1/key"
ATTENTION_V = "MultiHeadDotProductAttention_1/value"
ATTENTION_OUT = "MultiHeadDotProductAttention_1/out"
FC_0 = "MlpBlock_3/Dense_0"
FC_1 = "MlpBlock_3/Dense_1"
ATTENTION_NORM = "LayerNorm_0"
MLP_NORM = "LayerNorm_2"


def np2th(weights, conv=False):
	"""Possibly convert HWIO to OIHW."""
	if conv:
		weights = weights.transpose([3, 2, 0, 1])
	return torch.from_numpy(weights)


def swish(x):
	return x * torch.sigmoid(x)


ACT2FN = {"gelu": torch.nn.functional.gelu, "relu": torch.nn.functional.relu, "swish": swish}


class Mlp(nn.Module):
	def __init__(self, config):
		super(Mlp, self).__init__()
		self.fc1 = Linear(config.hidden_size, config.mlp_dim)
		self.fc2 = Linear(config.mlp_dim, config.hidden_size)
		self.act_fn = ACT2FN["gelu"]
		self.dropout = Dropout(config.dropout_rate)

		self._init_weights()

	def _init_weights(self):
		nn.init.xavier_uniform_(self.fc1.weight)
		nn.init.xavier_uniform_(self.fc2.weight)
		nn.init.normal_(self.fc1.bias, std=1e-6)
		nn.init.normal_(self.fc2.bias, std=1e-6)

	def forward(self, x):
		x = self.fc1(x)
		x = self.act_fn(x)
		x = self.dropout(x)
		x = self.fc2(x)
		x = self.dropout(x)
		return x


class Embeddings(nn.Module):
	"""Construct the embeddings from patch, position embeddings.
	"""

	def __init__(self, config, img_size, in_channels=3):
		super(Embeddings, self).__init__()
		img_size = _pair(img_size)

		patch_size = _pair(config.patches)
		n_patches = (img_size[0] // patch_size[0]) * (img_size[1] // patch_size[1])
		self.patch_embeddings = Conv2d(in_channels=in_channels,
		                               out_channels=config.hidden_size,
		                               kernel_size=patch_size,
		                               stride=patch_size)
		self.position_embeddings = nn.Parameter(torch.zeros(1, n_patches + 1, config.hidden_size))
		self.cls_token = nn.Parameter(torch.zeros(1, 1, config.hidden_size))

		self.dropout = Dropout(config.dropout_rate)

	def forward(self, x):
		B = x.shape[0]
		cls_tokens = self.cls_token.expand(B, -1, -1)

		x = self.patch_embeddings(x)
		x = x.flatten(2)
		x = x.transpose(-1, -2)
		x = torch.cat((cls_tokens, x), dim=1)

		embeddings = x + self.position_embeddings
		embeddings = self.dropout(embeddings)
		return embeddings


class Encoder(nn.Module):
	def __init__(self, config):
		super(Encoder, self).__init__()
		self.layer = nn.ModuleList()
		# for _ in range(config.num_layers):
		for _ in range(config.num_layers + 1):
			layer = Block(config)
			self.layer.append(copy.deepcopy(layer))

	def forward(self, hidden_states):
		# attmap = []
		for layer in self.layer:
			hidden_states, weights = layer(hidden_states)
		# print(weights.shape)
		# attmap.append(weights)
		return hidden_states


class Transformer(nn.Module):
	def __init__(self, config, img_size):
		super(Transformer, self).__init__()
		self.embeddings = Embeddings(config, img_size=img_size)
		self.encoder = Encoder(config)

	def forward(self, input_ids):
		embedding_output = self.embeddings(input_ids)
		part_encoded = self.encoder(embedding_output)
		return part_encoded


class LabelSmoothing(nn.Module):
	"""
	NLL loss with label smoothing.
	"""

	def __init__(self, smoothing=0.0):
		"""
		Constructor for the LabelSmoothing module.
		param smoothing: label smoothing factor
		"""
		super(LabelSmoothing, self).__init__()
		self.confidence = 1.0 - smoothing
		self.smoothing = smoothing

	def forward(self, x, target):
		logprobs = torch.nn.functional.log_softmax(x, dim=-1)
		nll_loss = -logprobs.gather(dim=-1, index=target.unsqueeze(1))
		nll_loss = nll_loss.squeeze(1)
		smooth_loss = -logprobs.mean(dim=-1)
		loss = self.confidence * nll_loss + self.smoothing * smooth_loss
		return loss.mean()


class Attention(nn.Module):
	def __init__(self, config, assess=False):
		super(Attention, self).__init__()
		self.assess = assess
		self.num_attention_heads = config.num_heads
		self.attention_head_size = int(config.hidden_size / self.num_attention_heads)
		self.all_head_size = self.num_attention_heads * self.attention_head_size

		self.query = Linear(config.hidden_size, self.all_head_size)
		self.key = Linear(config.hidden_size, self.all_head_size)
		self.value = Linear(config.hidden_size, self.all_head_size)

		self.out = Linear(config.hidden_size, config.hidden_size)
		self.attn_dropout = Dropout(config.att_dropout)
		self.proj_dropout = Dropout(config.att_dropout)

		self.softmax = Softmax(dim=-1)

	def transpose_for_scores(self, x):
		new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
		x = x.view(*new_x_shape)
		return x.permute(0, 2, 1, 3)

	def forward(self, hidden_states):
		mixed_query_layer = self.query(hidden_states)
		mixed_key_layer = self.key(hidden_states)
		mixed_value_layer = self.value(hidden_states)

		query_layer = self.transpose_for_scores(mixed_query_layer)
		key_layer = self.transpose_for_scores(mixed_key_layer)
		value_layer = self.transpose_for_scores(mixed_value_layer)

		attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
		attention_scores = attention_scores / math.sqrt(self.attention_head_size)
		attention_probs = self.softmax(attention_scores)
		weights = attention_probs
		attention_probs = self.attn_dropout(attention_probs)

		context_layer = torch.matmul(attention_probs, value_layer)
		context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
		new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
		context_layer = context_layer.view(*new_context_layer_shape)
		attention_output = self.out(context_layer)
		attention_output = self.proj_dropout(attention_output)
		if self.assess:
			return attention_output, weights, attention_scores
		else:
			return attention_output, weights


class Block(nn.Module):
	def __init__(self, config, assess=False):
		super(Block, self).__init__()
		self.assess = assess
		self.hidden_size = config.hidden_size
		self.attention_norm = LayerNorm(config.hidden_size, eps=1e-6)
		self.ffn_norm = LayerNorm(config.hidden_size, eps=1e-6)
		self.ffn = Mlp(config)
		self.attn = Attention(config, self.assess)

	def forward(self, x):
		h = x
		x = self.attention_norm(x)
		if self.assess:
			x, weights, score = self.attn(x)
		else:
			x, weights = self.attn(x)
		x = x + h

		h = x
		x = self.ffn_norm(x)
		x = self.ffn(x)
		x = x + h
		return x, weights

	def load_from(self, weights, n_block):
		ROOT = f"Transformer/encoderblock_{n_block}"
		with torch.no_grad():
			query_weight = np2th(weights[pjoin(ROOT, ATTENTION_Q, "kernel")]).view(self.hidden_size,
			                                                                       self.hidden_size).t()
			key_weight = np2th(weights[pjoin(ROOT, ATTENTION_K, "kernel")]).view(self.hidden_size, self.hidden_size).t()
			value_weight = np2th(weights[pjoin(ROOT, ATTENTION_V, "kernel")]).view(self.hidden_size,
			                                                                       self.hidden_size).t()
			out_weight = np2th(weights[pjoin(ROOT, ATTENTION_OUT, "kernel")]).view(self.hidden_size,
			                                                                       self.hidden_size).t()

			query_bias = np2th(weights[pjoin(ROOT, ATTENTION_Q, "bias")]).view(-1)
			key_bias = np2th(weights[pjoin(ROOT, ATTENTION_K, "bias")]).view(-1)
			value_bias = np2th(weights[pjoin(ROOT, ATTENTION_V, "bias")]).view(-1)
			out_bias = np2th(weights[pjoin(ROOT, ATTENTION_OUT, "bias")]).view(-1)

			self.attn.query.weight.copy_(query_weight)
			self.attn.key.weight.copy_(key_weight)
			self.attn.value.weight.copy_(value_weight)
			self.attn.out.weight.copy_(out_weight)
			self.attn.query.bias.copy_(query_bias)
			self.attn.key.bias.copy_(key_bias)
			self.attn.value.bias.copy_(value_bias)
			self.attn.out.bias.copy_(out_bias)

			mlp_weight_0 = np2th(weights[pjoin(ROOT, FC_0, "kernel")]).t()
			mlp_weight_1 = np2th(weights[pjoin(ROOT, FC_1, "kernel")]).t()
			mlp_bias_0 = np2th(weights[pjoin(ROOT, FC_0, "bias")]).t()
			mlp_bias_1 = np2th(weights[pjoin(ROOT, FC_1, "bias")]).t()

			self.ffn.fc1.weight.copy_(mlp_weight_0)
			self.ffn.fc2.weight.copy_(mlp_weight_1)
			self.ffn.fc1.bias.copy_(mlp_bias_0)
			self.ffn.fc2.bias.copy_(mlp_bias_1)

			self.attention_norm.weight.copy_(np2th(weights[pjoin(ROOT, ATTENTION_NORM, "scale")]))
			self.attention_norm.bias.copy_(np2th(weights[pjoin(ROOT, ATTENTION_NORM, "bias")]))
			self.ffn_norm.weight.copy_(np2th(weights[pjoin(ROOT, MLP_NORM, "scale")]))
			self.ffn_norm.bias.copy_(np2th(weights[pjoin(ROOT, MLP_NORM, "bias")]))

# if __name__ == '__main__':
# from core.vit import *
# config = get_b16_config()
import ml_collections
def get_b16_config():
	"""Returns the ViT-B/16 configuration."""
	config = ml_collections.ConfigDict()
	config.patches = (16, 16)
	config.hidden_size = 768
	config.mlp_dim = 3072
	config.num_heads = 12
	config.num_layers = 12
	config.att_dropout = 0.0
	config.dropout_rate = 0.1
	config.classifier = 'token'
	return config


In [None]:
import time

import numpy as np
from scipy import ndimage
from torch.nn import CrossEntropyLoss
import torch.nn.functional as F
# from models.modules import *
# from models.vit import get_b16_config


class InterEnsembleLearningTransformer(nn.Module):
	def __init__(self, config, img_size=448, num_classes=2, dataset='cub', smooth_value=0.,
	             loss_alpha=0.4, cam=True, dsm=True, fix=True, update_warm=500,
	             vote_perhead=24, total_num=126, assess=False):
		super(InterEnsembleLearningTransformer, self).__init__()
		self.assess = assess
		self.smooth_value = smooth_value
		self.num_classes = num_classes
		self.loss_alpha = loss_alpha
		self.cam = cam

		self.embeddings = Embeddings(config, img_size=img_size)
		self.encoder = IELTEncoder(config, update_warm, vote_perhead, dataset, cam, dsm,
		                           fix, total_num, assess)
		self.head = Linear(config.hidden_size, num_classes)
		self.softmax = Softmax(dim=-1)


	def forward(self, x, labels=None):
		test_mode = False if labels is not None else True
		x = self.embeddings(x)
		if self.assess:
			x, xc, assess_list = self.encoder(x, test_mode)
		else:
			x, xc = self.encoder(x, test_mode)

		if self.cam:
			complement_logits = self.head(xc)
			probability = self.softmax(complement_logits)
			weight = self.head.weight
			assist_logit = probability * (weight.sum(-1))
			part_logits = self.head(x) + assist_logit
		else:
			part_logits = self.head(x)

		if self.assess:
			return part_logits, assess_list

		elif test_mode:
			return part_logits

		else:
			if self.smooth_value == 0:
				loss_fct = CrossEntropyLoss()
			else:
				loss_fct = LabelSmoothing(self.smooth_value)

			if self.cam:
				loss_p = loss_fct(part_logits.view(-1, self.num_classes), labels.view(-1))
				loss_c = loss_fct(complement_logits.view(-1, self.num_classes), labels.view(-1))
				alpha = self.loss_alpha
				loss = (1 - alpha) * loss_p + alpha * loss_c
			else:
				loss = loss_fct(part_logits.view(-1, self.num_classes), labels.view(-1))
			return part_logits, loss

	def get_eval_data(self):
		return self.encoder.select_num

	def load_from(self, weights):
		with torch.no_grad():
			nn.init.zeros_(self.head.weight)
			nn.init.zeros_(self.head.bias)

			self.embeddings.patch_embeddings.weight.copy_(np2th(weights["embedding/kernel"], conv=True))
			self.embeddings.patch_embeddings.bias.copy_(np2th(weights["embedding/bias"]))
			self.embeddings.cls_token.copy_(np2th(weights["cls"]))
			# self.encoder.patch_norm.weight.copy_(np2th(weights["Transformer/encoder_norm/scale"]))
			# self.encoder.patch_norm.bias.copy_(np2th(weights["Transformer/encoder_norm/bias"]))
			# self.encoder.clr_encoder.patch_norm.weight.copy_(np2th(weights["Transformer/encoder_norm/scale"]))
			# self.encoder.clr_encoder.patch_norm.bias.copy_(np2th(weights["Transformer/encoder_norm/bias"]))

			posemb = np2th(weights["Transformer/posembed_input/pos_embedding"])
			posemb_new = self.embeddings.position_embeddings
			if posemb.size() == posemb_new.size():
				self.embeddings.position_embeddings.copy_(posemb)
			else:
				ntok_new = posemb_new.size(1)

				posemb_tok, posemb_grid = posemb[:, :1], posemb[0, 1:]
				ntok_new -= 1

				gs_old = int(np.sqrt(len(posemb_grid)))
				gs_new = int(np.sqrt(ntok_new))
				# print('load_pretrained: grid-size from %s to %s' % (gs_old, gs_new))
				posemb_grid = posemb_grid.reshape(gs_old, gs_old, -1)

				zoom = (gs_new / gs_old, gs_new / gs_old, 1)
				posemb_grid = ndimage.zoom(posemb_grid, zoom, order=1)
				posemb_grid = posemb_grid.reshape((1, gs_new * gs_new, -1))
				posemb = np.concatenate([posemb_tok, posemb_grid], axis=1)
				self.embeddings.position_embeddings.copy_(np2th(posemb))

			for bname, block in self.encoder.named_children():
				for uname, unit in block.named_children():
					if not bname.startswith('key') and not bname.startswith('clr'):
						if uname == '12':
							uname = '11'
						unit.load_from(weights, n_block=uname)


class MultiHeadVoting(nn.Module):
	def __init__(self, config, vote_perhead=24, fix=True):
		super(MultiHeadVoting, self).__init__()
		self.fix = fix
		self.num_heads = config.num_heads
		self.vote_perhead = vote_perhead

		if self.fix:
			self.kernel = torch.tensor([[1, 2, 1],
			                            [2, 4, 2],
			                            [1, 2, 1]], device='cuda').unsqueeze(0).unsqueeze(0).half()
			self.conv = F.conv2d
		else:
			self.conv = nn.Conv2d(1, 1, 3, 1, 1)

	def forward(self, x, select_num=None, last=False):
		B, patch_num = x.shape[0], x.shape[3] - 1
		select_num = self.vote_perhead if select_num is None else select_num
		count = torch.zeros((B, patch_num), dtype=torch.int, device='cuda').half()
		score = x[:, :, 0, 1:]
		_, select = torch.topk(score, self.vote_perhead, dim=-1)
		select = select.reshape(B, -1)

		for i, b in enumerate(select):
			count[i, :] += torch.bincount(b, minlength=patch_num)

		if not last:
			count = self.enhace_local(count)
			pass

		patch_value, patch_idx = torch.sort(count, dim=-1, descending=True)
		patch_idx += 1
		return patch_idx[:, :select_num], count

	def enhace_local(self, count):
		B, H = count.shape[0], math.ceil(math.sqrt(count.shape[1]))
		count = count.reshape(B, H, H)
		if self.fix:
			count = self.conv(count.unsqueeze(1), self.kernel, stride=1, padding=1).reshape(B, -1)
		else:
			count = self.conv(count.unsqueeze(1)).reshape(B, -1)
		return count


class CrossLayerRefinement(nn.Module):
	def __init__(self, config, clr_layer):
		super(CrossLayerRefinement, self).__init__()
		self.clr_layer = clr_layer
		self.clr_norm = LayerNorm(config.hidden_size, eps=1e-6)

	def forward(self, x, cls):
		out = [torch.stack(token) for token in x]
		out = torch.stack(out).squeeze(1)
		out = torch.cat((cls, out), dim=1)
		out, weights = self.clr_layer(out)
		out = self.clr_norm(out)
		return out, weights


class IELTEncoder(nn.Module):
	def __init__(self, config, update_warm=500, vote_perhead=24, dataset='cub',
	             cam=True, dsm=True, fix=True, total_num=126, assess=False):
		super(IELTEncoder, self).__init__()
		self.assess = assess
		self.warm_steps = update_warm
		self.layer = nn.ModuleList()
		self.layer_num = config.num_layers
		self.vote_perhead = vote_perhead
		self.dataset = dataset
		self.cam = cam
		self.dsm = dsm

		for _ in range(self.layer_num - 1):
			self.layer.append(Block(config, assess=self.assess))

		if self.dataset == 'dog' or self.dataset == 'nabrids':
			self.layer.append(Block(config, assess=self.assess))
			self.clr_layer = self.layer[-1]
			if self.cam:
				self.layer.append(Block(config, assess=self.assess))
				self.key_layer = self.layer[-1]
		else:
			self.clr_layer = Block(config)
			if self.cam:
				self.key_layer = Block(config)

		if self.cam:
			self.key_norm = LayerNorm(config.hidden_size, eps=1e-6)

		self.patch_select = MultiHeadVoting(config, self.vote_perhead, fix)

		self.total_num = total_num
		## for CUB and NABirds
		self.select_rate = torch.tensor([16, 14, 12, 10, 8, 6, 8, 10, 12, 14, 16], device='cuda') / self.total_num
		## for Others
		# self.select_rate = torch.ones(self.layer_num-1,device='cuda')/(self.layer_num-1)

		self.select_num = self.select_rate * self.total_num
		self.clr_encoder = CrossLayerRefinement(config, self.clr_layer)
		self.count = 0

	def forward(self, hidden_states, test_mode=False):
		if not test_mode:
			self.count += 1
		B, N, C = hidden_states.shape
		complements = [[] for i in range(B)]
		class_token_list = []
		if self.assess:
			layer_weights = []
			layer_selected = []
			layer_score = []
		else:
			pass

		for t in range(self.layer_num - 1):
			layer = self.layer[t]
			select_num = torch.round(self.select_num[t]).int()
			hidden_states, weights = layer(hidden_states)
			select_idx, select_score = self.patch_select(weights, select_num)
			for i in range(B):
				complements[i].extend(hidden_states[i, select_idx[i, :]])
			class_token_list.append(hidden_states[:, 0].unsqueeze(1))
			if self.assess:
				layer_weights.append(weights)
				layer_score.append(select_score)
				layer_selected.extend(select_idx)
		cls_token = hidden_states[:, 0].unsqueeze(1)

		clr, weights = self.clr_encoder(complements, cls_token)
		sort_idx, _ = self.patch_select(weights, select_num=24, last=True)

		if not test_mode and self.count >= self.warm_steps and self.dsm:
			# if not test_mode and self.count >= 500 and self.dsm:
			layer_count = self.count_patch(sort_idx)
			self.update_layer_select(layer_count)

		class_token_list = torch.cat(class_token_list, dim=1)

		if not self.cam:
			return clr[:, 0], None
		else:
			out = []
			for i in range(B):
				out.append(clr[i, sort_idx[i, :]])
			out = torch.stack(out).squeeze(1)
			out = torch.cat((cls_token, out), dim=1)
			out, _ = self.key_layer(out)
			key = self.key_norm(out)

		if self.assess:
			assess_list = [layer_weights, layer_selected, layer_score, sort_idx]
			return key[:, 0], clr[:, 0], assess_list
		else:

			# fused = torch.cat((class_token_list, clr[:, 0].unsqueeze(1)), dim=1)
			# clr[:, 0] = fused.mean(1)
			return key[:, 0], clr[:, 0]

	def update_layer_select(self, layer_count):
		alpha = 1e-3  # if self.dataset != 'dog' and self.dataset == 'nabirds' else 1e-4
		new_rate = layer_count / layer_count.sum()

		self.select_rate = self.select_rate * (1 - alpha) + alpha * new_rate
		self.select_rate /= self.select_rate.sum()
		self.select_num = self.select_rate * self.total_num

	def count_patch(self, sort_idx):
		layer_count = torch.cumsum(self.select_num, dim=-1)
		sort_idx = (sort_idx - 1).reshape(-1)
		for i in range(self.layer_num - 1):
			mask = (sort_idx < layer_count[i])
			layer_count[i] = mask.sum()
		cum_count = torch.cat((torch.tensor([0], device='cuda'), layer_count[:-1]))
		layer_count -= cum_count
		return layer_count.int()

	## Old Implementation
	# layer_count = torch.zeros(self.layer_num, device='cuda').int()
	# sort_idx = (sort_idx - 1).reshape(-1)
	# sorted, _ = torch.sort(sort_idx)
	# for j in range(self.layer_num):
	# 	if j == (self.layer_num - 1):
	# 		layer_count[j] = len(sorted)
	# 		break
	# 	a = self.select_num[:j + 1].sum()
	# 	for i, val in enumerate(sorted):
	# 		flag = True
	# 		if flag and val > a:
	# 			layer_count[j] += i
	# 			sorted = sorted[i:]
	# 			flag = False
	# 		if not flag:
	# 			break
	# return layer_count

!pip install ml-collections
if __name__ == '__main__':
	start = time.time()
	config = get_b16_config()
	# com = clrEncoder(config,)
	# com.to(device='cuda')
	net = InterEnsembleLearningTransformer(config).cuda()
	# hidden_state = torch.arange(400*768).reshape(2,200,768)/1.0
	x = torch.rand(1, 3, 448, 448, device='cuda')
	y = net(x)
	print(y.shape)

torch.Size([4, 2])


# Audio with this ensemble learning

Conformer.py from Tuan

In [1]:
import math
import torch
from torch import nn, einsum
import torch.nn.functional as F

from einops import rearrange
from einops.layers.torch import Rearrange

# helper functions

def exists(val):
    return val is not None

def default(val, d):
    return val if exists(val) else d

def calc_same_padding(kernel_size):
    pad = kernel_size // 2
    return (pad, pad - (kernel_size + 1) % 2)

# helper classes

class Swish(nn.Module):
    def forward(self, x):
        return x * x.sigmoid()

class GLU(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.dim = dim

    def forward(self, x):
        out, gate = x.chunk(2, dim=self.dim)
        return out * gate.sigmoid()

class DepthWiseConv1d(nn.Module):
    def __init__(self, chan_in, chan_out, kernel_size, padding):
        super().__init__()
        self.padding = padding
        self.conv = nn.Conv1d(chan_in, chan_out, kernel_size, groups = chan_in)

    def forward(self, x):
        x = F.pad(x, self.padding)
        return self.conv(x)

# attention, feedforward, and conv module

class Scale(nn.Module):
    def __init__(self, scale, fn):
        super().__init__()
        self.fn = fn
        self.scale = scale

    def forward(self, x, **kwargs):
        return self.fn(x, **kwargs) * self.scale

class PreNorm(nn.Module):
    def __init__(self, dim, fn):
        super().__init__()
        self.fn = fn
        self.norm = nn.LayerNorm(dim)

    def forward(self, x, **kwargs):
        x = self.norm(x)
        return self.fn(x, **kwargs)

class Attention(nn.Module):
    # Head Token attention: https://arxiv.org/pdf/2210.05958.pdf
    def __init__(self, dim, heads=8, dim_head=64, qkv_bias=False, dropout=0., proj_drop=0.):
        super().__init__()
        self.num_heads = heads
        inner_dim = dim_head * heads
        self.scale = dim_head ** -0.5

        self.qkv = nn.Linear(dim, inner_dim * 3, bias=qkv_bias)

        self.attn_drop = nn.Dropout(dropout)
        self.proj = nn.Linear(inner_dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

        self.act = nn.GELU()
        self.ht_proj = nn.Linear(dim_head, dim,bias=True)
        self.ht_norm = nn.LayerNorm(dim_head)
        self.pos_embed = nn.Parameter(torch.zeros(1, self.num_heads, dim))

    def forward(self, x, mask=None):
        B, N, C = x.shape

        # head token
        head_pos = self.pos_embed.expand(x.shape[0], -1, -1)
        x_ = x.reshape(B, -1, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3)
        x_ = x_.mean(dim=2)  # now the shape is [B, h, 1, d//h]
        x_ = self.ht_proj(x_).reshape(B, -1, self.num_heads, C // self.num_heads)
        x_ = self.act(self.ht_norm(x_)).flatten(2)
        x_ = x_ + head_pos
        x = torch.cat([x, x_], dim=1)

        # normal mhsa
        qkv = self.qkv(x).reshape(B, N+self.num_heads, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
        q, k, v = qkv[0], qkv[1], qkv[2]   # make torchscript happy (cannot use tensor as tuple)

        attn = (q @ k.transpose(-2, -1)) * self.scale
        attn = attn.softmax(dim=-1)
        # attn = self.attn_drop(attn)

        x = (attn @ v).transpose(1, 2).reshape(B, N+self.num_heads, C)
        x = self.proj(x)

        # merge head tokens into cls token
        cls, patch, ht = torch.split(x, [1, N-1, self.num_heads], dim=1)
        cls = cls + torch.mean(ht, dim=1, keepdim=True) + torch.mean(patch, dim=1, keepdim=True)
        x = torch.cat([cls, patch], dim=1)

        x = self.proj_drop(x)

        return x, attn


class FeedForward(nn.Module):
    def __init__(
        self,
        dim,
        mult = 4,
        dropout = 0.
    ):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(dim, dim * mult),
            Swish(),
            nn.Dropout(dropout),
            nn.Linear(dim * mult, dim),
            nn.Dropout(dropout)
        )

    def forward(self, x):
        return self.net(x)

class ConformerConvModule(nn.Module):
    def __init__(
        self,
        dim,
        causal = False,
        expansion_factor = 2,
        kernel_size = 31,
        dropout = 0.
    ):
        super().__init__()

        inner_dim = dim * expansion_factor
        padding = calc_same_padding(kernel_size) if not causal else (kernel_size - 1, 0)

        self.net = nn.Sequential(
            nn.LayerNorm(dim),
            Rearrange('b n c -> b c n'),
            nn.Conv1d(dim, inner_dim * 2, 1),
            GLU(dim=1),
            DepthWiseConv1d(inner_dim, inner_dim, kernel_size = kernel_size, padding = padding),
            nn.BatchNorm1d(inner_dim) if not causal else nn.Identity(),
            Swish(),
            nn.Conv1d(inner_dim, dim, 1),
            Rearrange('b c n -> b n c'),
            nn.Dropout(dropout)
        )

    def forward(self, x):
        return self.net(x)

# Conformer Block

class ConformerBlock(nn.Module):
    def __init__(
        self,
        *,
        dim,
        dim_head = 64,
        heads = 8,
        ff_mult = 4,
        conv_expansion_factor = 2,
        conv_kernel_size = 31,
        attn_dropout = 0.,
        ff_dropout = 0.,
        conv_dropout = 0.,
        conv_causal = False
    ):
        super().__init__()
        self.ff1 = FeedForward(dim = dim, mult = ff_mult, dropout = ff_dropout)
        self.attn = Attention(dim = dim, dim_head = dim_head, heads = heads, dropout = attn_dropout)
        self.conv = ConformerConvModule(dim = dim, causal = conv_causal, expansion_factor = conv_expansion_factor, kernel_size = conv_kernel_size, dropout = conv_dropout)
        self.ff2 = FeedForward(dim = dim, mult = ff_mult, dropout = ff_dropout)

        self.attn = PreNorm(dim, self.attn)
        self.ff1 = Scale(0.5, PreNorm(dim, self.ff1))
        self.ff2 = Scale(0.5, PreNorm(dim, self.ff2))

        self.post_norm = nn.LayerNorm(dim)

    def forward(self, x, mask = None):
        x = self.ff1(x) + x
        attn_x, attn_weight = self.attn(x, mask = mask)
        x = attn_x + x
        x = self.conv(x) + x
        x = self.ff2(x) + x
        x = self.post_norm(x)
        return x, attn_weight

# Conformer

class Conformer(nn.Module):
    def __init__(
        self,
        dim,
        *,
        depth,
        dim_head = 64,
        heads = 8,
        ff_mult = 4,
        conv_expansion_factor = 2,
        conv_kernel_size = 31,
        attn_dropout = 0.,
        ff_dropout = 0.,
        conv_dropout = 0.,
        conv_causal = False
    ):
        super().__init__()
        self.dim = dim
        self.layers = nn.ModuleList([])

        for _ in range(depth):
            self.layers.append(ConformerBlock(
                dim = dim,
                dim_head = dim_head,
                heads = heads,
                ff_mult = ff_mult,
                conv_expansion_factor = conv_expansion_factor,
                conv_kernel_size = conv_kernel_size,
                conv_causal = conv_causal

            ))

    def forward(self, x):

        for block in self.layers:
            x = block(x)

        return x

In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class AudioInterEnsembleLearningTransformer(nn.Module):
    def __init__(self, config, num_classes=2, emb_size=128, timesteps=100, vote_perhead=24, total_num=126, assess=False):
        super(AudioInterEnsembleLearningTransformer, self).__init__()
        self.assess = assess
        self.num_classes = num_classes
        self.vote_perhead = vote_perhead
        self.total_num = total_num


        self.positional_emb = nn.Parameter(self.sinusoidal_embedding(timesteps, emb_size), requires_grad=False)


        self.encoder = IELTEncoder(config, vote_perhead=vote_perhead, total_num=total_num, assess=assess)


        self.head = nn.Linear(config.hidden_size, num_classes)
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, x, labels=None):
        test_mode = False if labels is not None else True

        print("Input shape:", x.shape)

        self.positional_emb = self.positional_emb.to(x.device)

        print("Positional embedding shape:", self.positional_emb[:x.size(1), :].unsqueeze(0).shape)

        x = x + self.positional_emb[:x.size(1), :].unsqueeze(0)

        print("Shape after positional embedding:", x.shape)

        if self.assess:
            x, xc, assess_list = self.encoder(x, test_mode)
        else:
            x, xc = self.encoder(x, test_mode)

        print("Encoder output shape:", xc.shape)

        complement_logits = self.head(xc)

        print("Shape after linear (head) layer:", complement_logits.shape)

        probability = self.softmax(complement_logits)
        weight = self.head.weight
        assist_logit = probability * weight.sum(-1)

        print("Shape of assist_logit:", assist_logit.shape)

        part_logits = self.head(x) + assist_logit

        print("Final part_logits shape:", part_logits.shape)

        if self.assess:
            return part_logits, assess_list
        elif test_mode:
            return part_logits
        else:
            loss_fct = nn.CrossEntropyLoss()
            loss = loss_fct(part_logits.view(-1, self.num_classes), labels.view(-1))
            return part_logits, loss


    def sinusoidal_embedding(self, n_position, d_model):
        # Standard sinusoidal embedding used in transformers
        position = torch.arange(0, n_position, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-torch.log(torch.tensor(10000.0)) / d_model))
        pe = torch.zeros(n_position, d_model)
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        return pe


class IELTEncoder(nn.Module):
    def __init__(self, config, vote_perhead=24, total_num=126, assess=False):
        super(IELTEncoder, self).__init__()
        self.assess = assess
        self.layer_num = config.num_layers
        self.vote_perhead = vote_perhead

        # I Replaced Block with ConformerBlock
        self.layer = nn.ModuleList([
            ConformerBlock(dim=config.hidden_size, dim_head=config.hidden_size // config.num_heads, heads=config.num_heads)
            for _ in range(self.layer_num - 1)
        ])

        self.clr_layer = ConformerBlock(dim=config.hidden_size, dim_head=config.hidden_size // config.num_heads, heads=config.num_heads)
        self.clr_encoder = CrossLayerRefinement(config, self.clr_layer)
        self.patch_select = MultiHeadVoting(config, self.vote_perhead)

    def forward(self, hidden_states, test_mode=False):
        B, N, C = hidden_states.shape
        complements = [[] for _ in range(B)]
        class_token_list = []

        for t in range(self.layer_num - 1):
            layer = self.layer[t]
            hidden_states, weights = layer(hidden_states)  # Using ConformerBlock
            select_idx, select_score = self.patch_select(weights)
            for i in range(B):
                complements[i].extend(hidden_states[i, select_idx[i, :]])
            class_token_list.append(hidden_states[:, 0].unsqueeze(1))

        cls_token = hidden_states[:, 0].unsqueeze(1)
        clr, weights = self.clr_encoder(complements, cls_token)

        return clr[:, 0], weights


class MultiHeadVoting(nn.Module):
    def __init__(self, config, vote_perhead=24, fix=True):
        super(MultiHeadVoting, self).__init__()
        self.fix = fix
        self.num_heads = config.num_heads
        self.vote_perhead = vote_perhead

        if self.fix:
            self.kernel = torch.tensor([[1, 2, 1],
                                        [2, 4, 2],
                                        [1, 2, 1]], device='cuda').unsqueeze(0).unsqueeze(0).half()
            self.conv = F.conv2d
        else:
            self.conv = nn.Conv2d(1, 1, 3, 1, 1)

    def forward(self, x, select_num=None, last=False):
        B, patch_num = x.shape[0], x.shape[3] - 1
        select_num = self.vote_perhead if select_num is None else select_num
        count = torch.zeros((B, patch_num), dtype=torch.int, device='cuda').half()
        score = x[:, :, 0, 1:]
        _, select = torch.topk(score, self.vote_perhead, dim=-1)
        select = select.reshape(B, -1)

        for i, b in enumerate(select):
            # b = b.cuda()
            count[i, :] += torch.bincount(b, minlength=patch_num)

        if not last:
            count = self.enhance_local(count)

        patch_value, patch_idx = torch.sort(count, dim=-1, descending=True)
        patch_idx += 1
        return patch_idx[:, :select_num], count

    def enhance_local(self, count):
        # B, H = count.shape[0], math.ceil(math.sqrt(count.shape[1]))
        # count = count.reshape(B, H, H)
        B = count.shape[0]
        H = int(math.sqrt(count.shape[1]))
        if self.fix:
            count = self.conv(count.unsqueeze(1).unsqueeze(1), self.kernel, stride=1, padding=1).reshape(B, -1)
        else:
            count = self.conv(count.unsqueeze(1)).reshape(B, -1)
        return count


class CrossLayerRefinement(nn.Module):
    def __init__(self, config, clr_layer):
        super(CrossLayerRefinement, self).__init__()
        self.clr_layer = clr_layer
        self.clr_norm = nn.LayerNorm(config.hidden_size, eps=1e-6)

    def forward(self, complements, cls_token):
        out = [torch.stack(token) for token in complements]
        out = torch.stack(out).squeeze(1)
        out = torch.cat((cls_token, out), dim=1)
        out, weights = self.clr_layer(out)
        out = self.clr_norm(out)
        return out, weights


In [3]:

batch_size = 1
timesteps = 100
emb_size = 128
num_classes = 2

input_tensor = torch.rand(batch_size, timesteps, emb_size)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

input_tensor = input_tensor.to(device)

class Config:
    hidden_size = emb_size
    num_layers = 6
    num_heads = 8

config = Config()


model = AudioInterEnsembleLearningTransformer(config, num_classes=num_classes, emb_size=emb_size, timesteps=timesteps)
model.to(device)

output, attn_weights = model(input_tensor)

print("Output shape:", output.shape)
print("Output:", output)


Input shape: torch.Size([1, 100, 128])
Positional embedding shape: torch.Size([1, 100, 128])
Shape after positional embedding: torch.Size([1, 100, 128])


RuntimeError: CUDA error: device-side assert triggered
CUDA kernel errors might be asynchronously reported at some other API call, so the stacktrace below might be incorrect.
For debugging consider passing CUDA_LAUNCH_BLOCKING=1
Compile with `TORCH_USE_CUDA_DSA` to enable device-side assertions.
