In [1]:
# Enable IPython's autoreload extension; mode 2 re-imports all modules before
# each cell executes, so edits to local modules (e.g. `model.py`) are picked up
# without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [2]:
import glob
import os
import sys

import librosa
import librosa.display
import matplotlib.pyplot as plt
import numpy as np
import pytorch_lightning as pl
import simplejpeg
import torch
import torch.nn as nn
import torchvision as tv
from ay2.torch.deepfake_detection import DeepfakeAudioClassification
from IPython.display import Audio, display
from PIL import Image

In [3]:
from model import AudioCLIP

In [None]:
# The relative form only resolves when this code is executed as part of a
# package; inside a notebook kernel there is no parent package and Python
# raises ImportError, so fall back to the absolute import in that case.
try:
    from .model import AudioCLIP
except ImportError:
    from model import AudioCLIP

## Test

In [4]:
# Smoke test: load the pretrained AudioCLIP checkpoint and embed a random batch.
PRETRAINED_PATH = (
    "/home/ay/data/DATA/0-model_weights/AudioClip/AudioCLIP-Full-Training.pt"
)
BATCH_SIZE = 2
NUM_SAMPLES = 48000  # length of each random waveform fed to the model
FEATURE_DIM = 1024  # width of the audio embedding returned by AudioCLIP

aclp = AudioCLIP(pretrained=PRETRAINED_PATH)

# A privately seeded generator makes the random input (and therefore the
# displayed features) reproducible without touching the global RNG state.
rng = torch.Generator().manual_seed(0)
audio = torch.randn(BATCH_SIZE, NUM_SAMPLES, generator=rng)

# Only the audio features are needed; the remaining outputs are discarded.
((audio_features, _, _), _), _ = aclp(audio=audio)

assert audio_features.shape == (BATCH_SIZE, FEATURE_DIM), audio_features.shape
audio_features, audio_features.shape

(tensor([[ 0.0030, -0.0025,  0.0486,  ...,  0.0038,  0.0334,  0.0075],
         [-0.0029, -0.0032,  0.0526,  ...,  0.0043,  0.0379,  0.0111]],
        grad_fn=<DivBackward0>),
 torch.Size([2, 1024]))

## Lightning (Lit) model

In [None]:
class AudioClip(nn.Module):
    """Single-logit classifier built on AudioCLIP audio embeddings.

    Wraps an ``AudioCLIP`` backbone and a linear layer that maps its
    1024-dimensional audio feature vector to one output logit.
    """

    # Default checkpoint; kept as the default so ``AudioClip()`` behaves exactly
    # as it did when the path was hard-coded inside ``__init__``.
    DEFAULT_PRETRAINED = (
        "/home/ay/data/DATA/0-model_weights/AudioClip/AudioCLIP-Full-Training.pt"
    )
    # Width of the audio feature vector consumed by the projection layer.
    FEATURE_DIM = 1024

    def __init__(self, pretrained=DEFAULT_PRETRAINED, **kwargs):
        """Build the backbone and the projection head.

        Args:
            pretrained: Value forwarded to ``AudioCLIP(pretrained=...)``;
                defaults to the checkpoint path this class always used.
            **kwargs: Accepted for call-site compatibility and ignored.
        """
        super().__init__()
        self.model = AudioCLIP(pretrained=pretrained)
        self.proj = nn.Linear(self.FEATURE_DIM, 1)

    def forward(self, x):
        """Return the logit(s) for the audio batch ``x``.

        Equivalent to ``make_prediction(extract_feature(x))``.
        """
        return self.make_prediction(self.extract_feature(x))

    def extract_feature(self, x):
        """Return the backbone's audio features for ``x``.

        ``x`` is passed to the backbone as its ``audio`` argument.
        NOTE(review): presumably a (batch, samples) waveform tensor, as in the
        notebook's smoke test -- confirm against the data pipeline.
        """
        # The backbone returns a nested tuple; only the audio features are kept.
        ((audio_features, _, _), _), _ = self.model(audio=x)
        return audio_features

    def make_prediction(self, audio_features):
        """Project audio features (last dim 1024) to a single logit per row."""
        return self.proj(audio_features)

In [None]:
class AudioClip_lit(DeepfakeAudioClassification):
    """Training wrapper that pairs the ``AudioClip`` network with a BCE-with-logits loss.

    NOTE(review): the base class comes from ``ay2``; it presumably drives the
    training/validation steps and calls the hooks defined here -- confirm.
    """

    def __init__(self, backend="linear", **kwargs):
        """Create the network and loss, then record the constructor arguments.

        Args:
            backend: Not read by any code in this class; it is only captured
                by ``save_hyperparameters()``.
            **kwargs: Likewise unused here apart from being recorded.
        """
        super().__init__()
        self.model = AudioClip()
        self.loss_fn = nn.BCEWithLogitsLoss()
        # Stores the ``__init__`` arguments (backend, kwargs) as hyperparameters.
        self.save_hyperparameters()

    # The spelling "calcuate" is kept on purpose: renaming could break whatever
    # calls this hook (presumably the base class -- verify before changing).
    def calcuate_loss(self, batch_res, batch):
        """Return the BCE-with-logits loss between ``batch_res["logit"]`` and ``batch["label"]``.

        Labels are cast to float32 because ``BCEWithLogitsLoss`` needs
        floating-point targets.
        """
        target = batch["label"].float()
        return self.loss_fn(batch_res["logit"], target)

    def configure_optimizers(self):
        """Return a one-element list holding an Adam optimizer over ``self.model``'s parameters."""
        adam = torch.optim.Adam(
            params=self.model.parameters(),
            weight_decay=0.0001,
            lr=0.0001,
        )
        return [adam]
