In [1]:
import torch
import torch.nn as nn
from torch.nn import TransformerEncoder, TransformerEncoderLayer
import matplotlib.pyplot as plt

In [2]:
class AudioEnhancer(nn.Module):
    """CNN encoder -> Transformer encoder -> CNN decoder for 2-channel 1-D signals.

    The encoder widens the 2 input channels through the widths listed in
    ``cnn_filters``; the transformer attends over the time axis using the last
    width as its model dimension; the decoder mirrors the encoder back down to
    2 channels and squashes the result with ``Tanh``.

    Every convolution uses ``kernel_size=3, padding=1`` (stride 1), so the
    length of the time axis is preserved end to end.

    Args:
        num_transformer_layers: Number of stacked ``TransformerEncoderLayer``s.
        num_heads: Attention heads per transformer layer. PyTorch's multi-head
            attention requires ``cnn_filters[-1]`` to be divisible by this.
        cnn_filters: Non-empty sequence of channel widths, one per encoder
            convolution. Any length is accepted; the decoder is built as the
            mirror image. With four widths the layer layout (and therefore
            the ``state_dict`` keys) is ``encoder.{0,2,4,6}`` /
            ``decoder.{0,2,4,6}``.

    Raises:
        ValueError: If ``cnn_filters`` is empty.
    """

    def __init__(self, num_transformer_layers=2, num_heads=8, cnn_filters=(32, 64, 128, 256)):
        super(AudioEnhancer, self).__init__()

        # Work on a private copy so the caller's sequence (or the default) is
        # never shared with, or mutated through, this instance.
        filters = list(cnn_filters)
        if not filters:
            raise ValueError("cnn_filters must contain at least one channel width")

        # CNN Encoder: 2 -> filters[0] -> ... -> filters[-1], ReLU after each conv.
        encoder_layers = []
        in_channels = 2
        for out_channels in filters:
            encoder_layers.append(nn.Conv1d(in_channels, out_channels, kernel_size=3, padding=1))
            encoder_layers.append(nn.ReLU())
            in_channels = out_channels
        self.encoder = nn.Sequential(*encoder_layers)

        # Transformer: d_model is the encoder's output width, so the two always agree
        # regardless of how many filters were supplied.
        self.transformer = TransformerEncoder(
            TransformerEncoderLayer(
                d_model=filters[-1],
                nhead=num_heads,
                dim_feedforward=512,
                activation='relu',
                batch_first=True
            ),
            num_layers=num_transformer_layers
        )

        # CNN Decoder: filters[-1] -> ... -> filters[0] -> 2.
        # ReLU between layers, Tanh on the final one to bound the output to [-1, 1].
        decoder_widths = filters[::-1] + [2]
        last_index = len(decoder_widths) - 2
        decoder_layers = []
        for index, (in_channels, out_channels) in enumerate(zip(decoder_widths[:-1], decoder_widths[1:])):
            decoder_layers.append(
                nn.ConvTranspose1d(in_channels, out_channels, kernel_size=3, padding=1)
            )
            decoder_layers.append(nn.Tanh() if index == last_index else nn.ReLU())
        self.decoder = nn.Sequential(*decoder_layers)

    def forward(self, x):
        """Run the encoder, transformer and decoder on a batch.

        Args:
            x: Tensor shaped ``[batch_size, 2, length]`` (the original notes
                use ``length = 4800``).

        Returns:
            Tensor shaped ``[batch_size, 2, length]`` with values in
            ``[-1, 1]`` (final ``Tanh``).
        """
        # CNN Encoder
        x = self.encoder(x)  # Shape: [batch_size, cnn_filters[-1], length]

        # The transformer is batch_first, so it wants channels last.
        x = x.permute(0, 2, 1)  # Shape: [batch_size, length, cnn_filters[-1]]
        x = self.transformer(x)  # Shape: [batch_size, length, cnn_filters[-1]]
        x = x.permute(0, 2, 1)  # Shape: [batch_size, cnn_filters[-1], length]

        # CNN Decoder
        x = self.decoder(x)  # Shape: [batch_size, 2, length]

        return x



class PerceptualLoss(nn.Module):
    """Feature-space MSE plus sample-space MSE between a prediction and a target.

    Args:
        feature_extractor: Callable applied to both ``pred`` and ``target`` to
            obtain the features that the perceptual term compares.
    """

    def __init__(self, feature_extractor):
        super().__init__()
        self.feature_extractor = feature_extractor
        self.mse_loss = nn.MSELoss()

    def forward(self, pred, target):
        """Return ``MSE(features(pred), features(target)) + MSE(pred, target)``."""
        extract = self.feature_extractor
        criterion = self.mse_loss

        # Perceptual term: distance between the extracted features.
        feature_term = criterion(extract(pred), extract(target))

        # Reconstruction term: distance between the raw tensors.
        sample_term = criterion(pred, target)

        return feature_term + sample_term


class DummyFeatureExtractor(nn.Module):
    """Placeholder feature extractor for the perceptual loss.

    Two ``Conv1d`` layers (2 -> 16 -> 32 channels, ``kernel_size=3``,
    ``padding=1``), each followed by ``ReLU``.
    """

    def __init__(self):
        super().__init__()
        channel_plan = (2, 16, 32)
        stack = []
        # One conv + ReLU pair per consecutive (in, out) channel pair.
        for c_in, c_out in zip(channel_plan, channel_plan[1:]):
            stack += [nn.Conv1d(c_in, c_out, kernel_size=3, padding=1), nn.ReLU()]
        self.features = nn.Sequential(*stack)

    def forward(self, x):
        """Apply the conv stack to ``x`` and return the resulting feature map."""
        feature_map = self.features(x)
        return feature_map

In [4]:
# Load the saved model object and switch it to inference mode.
# `torch` comes from the imports cell at the top of the notebook.
model_save_path = '/home/j597s263/scratch/j597s263/Models/Audio/Exp_4.mod'

# NOTE(security): weights_only=False performs a full pickle load, which can
# execute arbitrary code from the file. Only use it on checkpoints you trust.
# Presumably the file holds a whole pickled AudioEnhancer, so the class cell
# above must have been run first — TODO confirm.
#
# map_location: keep PyTorch's default device restoration when a GPU is
# visible; fall back to CPU otherwise so a checkpoint containing CUDA tensors
# can still be opened on a CPU-only machine.
model = torch.load(
    model_save_path,
    map_location=None if torch.cuda.is_available() else 'cpu',
    weights_only=False,
)
model.eval()

print("Model Loaded")

Model Loaded
