In [1]:
import os
import dill
import pandas as pd
import numpy as np
import torch
import torchaudio
import torchvision.models as models
import torch.nn as nn
import clip
from torch.nn import functional as F
from torch.utils.data import DataLoader, Dataset
from torchvision.transforms import Resize, Compose, ToTensor, Normalize
from pathlib import Path
import torch.optim as optim
import torchaudio.transforms as T
import torchvision.transforms as VT
from tqdm import tqdm
from torch.optim.lr_scheduler import CosineAnnealingLR

from torch.utils.data import Dataset, DataLoader

In [2]:
def load_audio(filename):
    if os.path.getsize(filename) == 0:
        print(f"File {filename} is empty.")
        return None, None
    else:
        return torchaudio.load(filename)
    
def adjust_audio_length(waveform, target_length):
    current_length = waveform.shape[1]
    if current_length > target_length:
        # 截断长音频
        waveform = waveform[:, :target_length]
    elif current_length < target_length:
        # 填充短音频
        padding_size = target_length - current_length
        padding = torch.zeros((waveform.shape[0], padding_size))
        waveform = torch.cat((waveform, padding), 1)
    return waveform

class loadDataset(Dataset):
    def __init__(self, base_path, meta_path):
        self.base_path = base_path
        self.df = pd.read_csv(meta_path)
        self.max_length=2206

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        row = self.df.iloc[idx]
        filename = os.path.join(self.base_path, row['YTID'])
        label = row['positive_labels']
        waveform, sample_rate = load_audio(filename)
        
        if(waveform==None):
            return None
        else:
            waveform = adjust_audio_length(waveform, self.max_length)
        
        spectrogram = T.MelSpectrogram(sample_rate=sample_rate, n_mels=64, n_fft=1024)(waveform)

        
        if spectrogram.shape[0] > 1:
            spectrogram = spectrogram[0, :, :].unsqueeze(0)

        mean = [0.485]
        std = [0.229]
        spectrogram = VT.Normalize(mean=mean, std=std)(spectrogram)
        
        # current_length = spectrogram.shape[-1]
        # # print(current_length)
        # if current_length < self.max_length:
        #     # 需要填充的长度
        #     padding_length = self.max_length - current_length
        #     # 在最后一个维度上填充
        #     spectrogram = torch.nn.functional.pad(spectrogram, (0, padding_length))

        return spectrogram, label

In [3]:
base_path = 'train_wav/'
meta_path = '/train_test/labels.csv'

dataset = loadDataset(base_path, meta_path)
dataloader = DataLoader(dataset, batch_size=128,shuffle=True)

In [4]:
class AudioTextLoss(nn.Module):
    def __init__(self):
        super(AudioTextLoss, self).__init__()
        self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))
        self.loss_audio = nn.CrossEntropyLoss()
        self.loss_text = nn.CrossEntropyLoss()

    def forward(self, audio_features, text_features):
        # Normalize features
        audio_features = audio_features / audio_features.norm(dim=-1, keepdim=True)
        text_features = text_features / text_features.norm(dim=-1, keepdim=True)

        # Calculate cosine similarity as logits
        logit_scale = self.logit_scale.exp()
        logits_per_audio = logit_scale * audio_features @ text_features.t()
        logits_per_text = logit_scale * text_features @ audio_features.t()

        # Calculate batch size for ground truth
        batch_size = audio_features.shape[0]
        ground_truth = torch.arange(batch_size, dtype=torch.long, device=audio_features.device)

        # Compute loss as the average of audio-to-text and text-to-audio losses
        return (
            self.loss_audio(logits_per_audio, ground_truth)
            + self.loss_text(logits_per_text, ground_truth)
        ) / 2


In [5]:
class MLPLayers(nn.Module):
    def __init__(self, input_features, output_features, dropout_p=0.1):
        super(MLPLayers, self).__init__()
        self.nonlin = nn.ReLU()
        self.sequential = nn.Sequential(
            nn.Linear(input_features, output_features),
            nn.ReLU(),
            nn.Dropout(dropout_p),
            nn.Linear(output_features, output_features)
        )

    def forward(self, x):
        return self.sequential(self.nonlin(x))

resnet = models.resnet18(pretrained=False)

resnet.conv1 = nn.Conv2d(1, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)

num_ftrs = resnet.fc.in_features
resnet.fc = MLPLayers(input_features=num_ftrs, output_features=512)



In [6]:
device = "cuda" if torch.cuda.is_available() else "cpu"
clip_model, _ = clip.load("ViT-B/32", device=device)
clip_model = clip_model.eval()
loss_fn=AudioTextLoss()
resnet.to(device)
clip_model.to(device)

CLIP(
  (visual): VisionTransformer(
    (conv1): Conv2d(3, 768, kernel_size=(32, 32), stride=(32, 32), bias=False)
    (ln_pre): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
    (transformer): Transformer(
      (resblocks): Sequential(
        (0): ResidualAttentionBlock(
          (attn): MultiheadAttention(
            (out_proj): NonDynamicallyQuantizableLinear(in_features=768, out_features=768, bias=True)
          )
          (ln_1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
          (mlp): Sequential(
            (c_fc): Linear(in_features=768, out_features=3072, bias=True)
            (gelu): QuickGELU()
            (c_proj): Linear(in_features=3072, out_features=768, bias=True)
          )
          (ln_2): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
        )
        (1): ResidualAttentionBlock(
          (attn): MultiheadAttention(
            (out_proj): NonDynamicallyQuantizableLinear(in_features=768, out_features=768, bias=True)
          

In [7]:
def train(model, data_loader, optimizer, clip_model, epochs=30):
    model.train()
    for epoch in range(epochs):
        loss_epoch=0
        for spectrograms,labels in tqdm(data_loader, desc=f"Epoch {epoch+1}/{epochs}"):
            # print(labels)
            spectrograms = spectrograms.to(device)
            optimizer.zero_grad()
            audio_features = model(spectrograms)
            labels=["This is a sound of "+label for label in labels]
            text_tokens = clip.tokenize(labels).to(device)
            with torch.no_grad():
                text_features = clip_model.encode_text(text_tokens)
            text_features = text_features.to(dtype=torch.float)
            loss = loss_fn(audio_features, text_features)
            loss_epoch += loss
            loss.backward()
            optimizer.step()
            scheduler.step()
        print(f"Epoch: {epoch+1}, loss: {loss_epoch/len(dataset)}")

optimizer = optim.Adam(resnet.parameters(), lr=0.1)
scheduler = CosineAnnealingLR(optimizer, T_max=5, eta_min=0.01)

train(resnet, dataloader, optimizer, clip_model, epochs=30)

Epoch 1/30: 100%|██████████| 152/152 [08:51<00:00,  3.50s/it]


Epoch: 1, loss: 0.03742595762014389


Epoch 2/30: 100%|██████████| 152/152 [02:56<00:00,  1.16s/it]


Epoch: 2, loss: 0.03684844449162483


Epoch 3/30: 100%|██████████| 152/152 [03:02<00:00,  1.20s/it]


Epoch: 3, loss: 0.03641180694103241


Epoch 4/30: 100%|██████████| 152/152 [08:38<00:00,  3.41s/it]


Epoch: 4, loss: 0.03615729510784149


Epoch 5/30: 100%|██████████| 152/152 [11:14<00:00,  4.43s/it]


Epoch: 5, loss: 0.03592006117105484


Epoch 6/30: 100%|██████████| 152/152 [03:05<00:00,  1.22s/it]


Epoch: 6, loss: 0.035785023123025894


Epoch 7/30: 100%|██████████| 152/152 [02:57<00:00,  1.17s/it]


Epoch: 7, loss: 0.035687997937202454


Epoch 8/30: 100%|██████████| 152/152 [02:59<00:00,  1.18s/it]


Epoch: 8, loss: 0.035393379628658295


Epoch 9/30: 100%|██████████| 152/152 [03:00<00:00,  1.19s/it]


Epoch: 9, loss: 0.035214539617300034


Epoch 10/30: 100%|██████████| 152/152 [03:02<00:00,  1.20s/it]


Epoch: 10, loss: 0.03501944988965988


Epoch 11/30: 100%|██████████| 152/152 [03:02<00:00,  1.20s/it]


Epoch: 11, loss: 0.03494879603385925


Epoch 12/30: 100%|██████████| 152/152 [02:57<00:00,  1.17s/it]


Epoch: 12, loss: 0.034742362797260284


Epoch 13/30: 100%|██████████| 152/152 [02:59<00:00,  1.18s/it]


Epoch: 13, loss: 0.034620095044374466


Epoch 14/30: 100%|██████████| 152/152 [03:05<00:00,  1.22s/it]


Epoch: 14, loss: 0.03451976180076599


Epoch 15/30: 100%|██████████| 152/152 [02:59<00:00,  1.18s/it]


Epoch: 15, loss: 0.034322354942560196


Epoch 16/30: 100%|██████████| 152/152 [02:57<00:00,  1.17s/it]


Epoch: 16, loss: 0.03424769639968872


Epoch 17/30: 100%|██████████| 152/152 [02:54<00:00,  1.15s/it]


Epoch: 17, loss: 0.03410407900810242


Epoch 18/30: 100%|██████████| 152/152 [02:57<00:00,  1.17s/it]


Epoch: 18, loss: 0.03396962210536003


Epoch 19/30: 100%|██████████| 152/152 [02:57<00:00,  1.16s/it]


Epoch: 19, loss: 0.03384880721569061


Epoch 20/30: 100%|██████████| 152/152 [02:58<00:00,  1.18s/it]


Epoch: 20, loss: 0.03368601202964783


Epoch 21/30: 100%|██████████| 152/152 [03:02<00:00,  1.20s/it]


Epoch: 21, loss: 0.03356330469250679


Epoch 22/30: 100%|██████████| 152/152 [03:01<00:00,  1.20s/it]


Epoch: 22, loss: 0.03345692530274391


Epoch 23/30: 100%|██████████| 152/152 [03:01<00:00,  1.20s/it]


Epoch: 23, loss: 0.033281952142715454


Epoch 24/30: 100%|██████████| 152/152 [02:59<00:00,  1.18s/it]


Epoch: 24, loss: 0.033145759254693985


Epoch 25/30: 100%|██████████| 152/152 [03:01<00:00,  1.20s/it]


Epoch: 25, loss: 0.03301706165075302


Epoch 26/30: 100%|██████████| 152/152 [03:00<00:00,  1.19s/it]


Epoch: 26, loss: 0.0328570231795311


Epoch 28/30: 100%|██████████| 152/152 [02:58<00:00,  1.18s/it]


Epoch: 28, loss: 0.03252078965306282


Epoch 29/30: 100%|██████████| 152/152 [03:02<00:00,  1.20s/it]


Epoch: 29, loss: 0.0323791578412056


Epoch 30/30: 100%|██████████| 152/152 [02:57<00:00,  1.17s/it]

Epoch: 30, loss: 0.03220837190747261





In [8]:
model_save_path = "new_AudiosetResnet"
os.makedirs(model_save_path, exist_ok=True)

torch.save(resnet, os.path.join(model_save_path, "resnet_model_next.pth"), pickle_module=dill)